/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.ext;

import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.Vector;
import net.librec.recommender.MatrixRecommender;

/**
 * Bi-Polar Slope One recommender.
 *
 * <p>Each user's training ratings are split into "liked" items (rated at or above that user's
 * mean training rating) and "disliked" items (rated below it). Average item-to-item rating
 * deviations are learned separately from item pairs that a user liked both of and item pairs that
 * a user disliked both of. A prediction combines like deviations only from items the target user
 * liked and dislike deviations only from items the target user disliked, each weighted by the
 * number of users supporting that deviation.
 */
public class BipolarSlopeOneRecommender
extends MatrixRecommender {
    /** Mean of rating(i) - rating(j) over users who liked both i and j (holds sums until normalized). */
    private DenseMatrix likeDevMatrix;
    /** Mean of rating(i) - rating(j) over users who disliked both i and j (holds sums until normalized). */
    private DenseMatrix dislikeDevMatrix;
    /** Number of users who liked both items of each (i, j) pair. */
    private DenseMatrix likeCardMatrix;
    /** Number of users who disliked both items of each (i, j) pair. */
    private DenseMatrix dislikeCardMatrix;
    /**
     * Mean training rating of each user, used as that user's like/dislike threshold. Kept as a
     * double so the threshold is not truncated (e.g. a 3.5 mean must not classify 3.0 as "liked").
     */
    private double[] averageRating;

    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.likeDevMatrix = new DenseMatrix(this.numItems, this.numItems);
        this.likeCardMatrix = new DenseMatrix(this.numItems, this.numItems);
        this.dislikeDevMatrix = new DenseMatrix(this.numItems, this.numItems);
        this.dislikeCardMatrix = new DenseMatrix(this.numItems, this.numItems);
        this.averageRating = new double[this.numUsers];
    }

    /**
     * Learns the like/dislike deviation matrices and their support counts from the training data.
     *
     * <p>For every user and every ordered pair (i, j) of distinct items they rated, the deviation
     * rating(i) - rating(j) is accumulated into the like matrices when both ratings are at or above
     * the user's mean, or into the dislike matrices when both are below it. Mixed pairs are ignored.
     * The accumulated sums are then divided by their counts to obtain average deviations.
     *
     * @throws LibrecException as declared by the superclass contract
     */
    @Override
    protected void trainModel() throws LibrecException {
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            SequentialSparseVector itemRatingsVector = this.trainMatrix.row(userIdx);
            double threshold = itemRatingsVector.mean();
            this.averageRating[userIdx] = threshold;
            for (Vector.VectorEntry itemIdxRating : itemRatingsVector) {
                int itemIdx = itemIdxRating.index();
                double userItemRating = itemIdxRating.get();
                boolean itemLiked = userItemRating >= threshold;
                for (Vector.VectorEntry comparedItemIdxRating : itemRatingsVector) {
                    int comparedItemIdx = comparedItemIdxRating.index();
                    if (itemIdx == comparedItemIdx) {
                        continue;
                    }
                    double comparedRating = comparedItemIdxRating.get();
                    boolean comparedLiked = comparedRating >= threshold;
                    double deviation = userItemRating - comparedRating;
                    // Accumulate across users (not overwrite) so that the later division by the
                    // count yields an average deviation and the count reflects real support.
                    if (itemLiked && comparedLiked) {
                        accumulate(this.likeDevMatrix, itemIdx, comparedItemIdx, deviation);
                        accumulate(this.likeCardMatrix, itemIdx, comparedItemIdx, 1.0);
                    } else if (!itemLiked && !comparedLiked) {
                        accumulate(this.dislikeDevMatrix, itemIdx, comparedItemIdx, deviation);
                        accumulate(this.dislikeCardMatrix, itemIdx, comparedItemIdx, 1.0);
                    }
                }
            }
        }
        normalizeDeviations(this.likeDevMatrix, this.likeCardMatrix);
        normalizeDeviations(this.dislikeDevMatrix, this.dislikeCardMatrix);
    }

    /**
     * Predicts the rating of {@code itemIdx} for {@code userIdx} using Bi-Polar Slope One.
     *
     * <p>For each other item the user rated, the like deviation (if the user rated it at or above
     * their mean) or the dislike deviation (if below) is added to that rating to form a candidate
     * prediction. Candidates are averaged, weighted by the number of users supporting each deviation.
     *
     * @param userIdx inner user index
     * @param itemIdx inner item index
     * @return the weighted prediction, or {@code globalMean} when no deviation supports it
     * @throws LibrecException as declared by the superclass contract
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        SequentialSparseVector itemRatingsVector = this.trainMatrix.row(userIdx);
        double threshold = this.averageRating[userIdx];
        double predictRatings = 0.0;
        double cardinaryValues = 0.0;
        for (Vector.VectorEntry comparedItemIdxRating : itemRatingsVector) {
            int comparedItemIdx = comparedItemIdxRating.index();
            if (comparedItemIdx == itemIdx) {
                continue;
            }
            double comparedRating = comparedItemIdxRating.get();
            // Use only the deviation set matching the user's own opinion of the compared item,
            // with the same threshold rule as in training.
            boolean comparedLiked = comparedRating >= threshold;
            DenseMatrix devMatrix = comparedLiked ? this.likeDevMatrix : this.dislikeDevMatrix;
            DenseMatrix cardMatrix = comparedLiked ? this.likeCardMatrix : this.dislikeCardMatrix;
            double cardinaryValue = cardMatrix.get(itemIdx, comparedItemIdx);
            if (cardinaryValue > 0.0) {
                predictRatings += (devMatrix.get(itemIdx, comparedItemIdx) + comparedRating) * cardinaryValue;
                cardinaryValues += cardinaryValue;
            }
        }
        return cardinaryValues > 0.0 ? predictRatings / cardinaryValues : this.globalMean;
    }

    /** Adds {@code delta} to the entry at ({@code row}, {@code column}) of {@code matrix}. */
    private static void accumulate(DenseMatrix matrix, int row, int column, double delta) {
        matrix.set(row, column, matrix.get(row, column) + delta);
    }

    /**
     * Divides each accumulated deviation sum in {@code devMatrix} by the matching count in
     * {@code cardMatrix}, turning sums into averages. Entries with zero count are left untouched.
     */
    private void normalizeDeviations(DenseMatrix devMatrix, DenseMatrix cardMatrix) {
        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            for (int comparedItemIdx = 0; comparedItemIdx < this.numItems; ++comparedItemIdx) {
                double card = cardMatrix.get(itemIdx, comparedItemIdx);
                if (card > 0.0) {
                    double sum = devMatrix.get(itemIdx, comparedItemIdx);
                    devMatrix.set(itemIdx, comparedItemIdx, sum / card);
                }
            }
        }
    }
}

