/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.SequentialAccessSparseMatrix;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Bayesian Personalized Ranking (BPR) recommender.
 *
 * <p>Learns user and item latent factors by stochastic gradient descent on sampled
 * (user, observed item, unobserved item) triples. Each step minimizes
 * {@code -ln(logistic(x_ui - x_uj))}, where {@code x} is the model's predicted score. L2
 * regularization is applied to the factors touched by that step.
 */
@ModelData(value={"isRanking", "bpr", "userFactors", "itemFactors"})
public class BPRRecommender
extends MatrixFactorizationRecommender {
    /**
     * Delegates to the parent setup; this class adds no configuration of its own.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
    }

    /**
     * Trains the factors for up to {@code numIterations} epochs. Each epoch performs
     * {@code trainMatrix.size()} SGD steps on uniformly sampled triples. An epoch may end
     * training early when {@code earlyStop} is set and the loss has converged.
     *
     * @throws LibrecException if the training matrix has entries but no user has both an
     *     observed and an unobserved item, so no training triple can be sampled
     */
    @Override
    protected void trainModel() throws LibrecException {
        IntOpenHashSet[] userItemsSet = this.getUserItemsSet(this.trainMatrix);

        // Cache each user's observed item indices once instead of rebuilding the row view
        // on every SGD step (the training matrix does not change during training).
        int[][] userItemIndices = new int[this.numUsers][];
        // Users with at least one positive and at least one negative item. Sampling directly
        // from this list has the same distribution as rejection sampling over all users. Unlike
        // rejection sampling, it cannot spin forever when no user qualifies.
        int[] eligibleUsers = new int[this.numUsers];
        int numEligible = 0;
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            userItemIndices[userIdx] = this.trainMatrix.row(userIdx).getIndices();
            int numObserved = userItemsSet[userIdx].size();
            if (numObserved > 0 && numObserved < this.numItems) {
                eligibleUsers[numEligible++] = userIdx;
            }
        }

        int maxSample = this.trainMatrix.size();
        if (maxSample > 0 && numEligible == 0) {
            throw new LibrecException("BPR cannot sample training triples: no user has both observed and unobserved items");
        }

        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            for (int sampleCount = 0; sampleCount < maxSample; ++sampleCount) {
                int userIdx = eligibleUsers[Randoms.uniform(numEligible)];
                IntOpenHashSet itemSet = userItemsSet[userIdx];
                int[] itemIndices = userItemIndices[userIdx];
                int posItemIdx = itemIndices[Randoms.uniform(itemIndices.length)];

                // Rejection-sample an unobserved item. This terminates because an eligible user
                // has fewer than numItems observed items. contains(int) avoids boxing.
                int negItemIdx;
                do {
                    negItemIdx = Randoms.uniform(this.numItems);
                } while (itemSet.contains(negItemIdx));

                double posPredictRating = this.predict(userIdx, posItemIdx);
                double negPredictRating = this.predict(userIdx, negItemIdx);
                double diffValue = posPredictRating - negPredictRating;
                double lossValue = -Math.log(Maths.logistic(diffValue));
                this.loss += lossValue;
                // Gradient scale of -ln(logistic(d)) w.r.t. d is logistic(-d).
                double deriValue = Maths.logistic(-diffValue);

                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    // Read all three values before any update, so each gradient uses the
                    // pre-step parameters.
                    double userFactorValue = this.userFactors.get(userIdx, factorIdx);
                    double posItemFactorValue = this.itemFactors.get(posItemIdx, factorIdx);
                    double negItemFactorValue = this.itemFactors.get(negItemIdx, factorIdx);

                    this.userFactors.plus(userIdx, factorIdx, (double)this.learnRate * (deriValue * (posItemFactorValue - negItemFactorValue) - (double)this.regUser * userFactorValue));
                    this.itemFactors.plus(posItemIdx, factorIdx, (double)this.learnRate * (deriValue * userFactorValue - (double)this.regItem * posItemFactorValue));
                    this.itemFactors.plus(negItemIdx, factorIdx, (double)this.learnRate * (deriValue * -userFactorValue - (double)this.regItem * negItemFactorValue));

                    this.loss += (double)this.regUser * userFactorValue * userFactorValue
                            + (double)this.regItem * posItemFactorValue * posItemFactorValue
                            + (double)this.regItem * negItemFactorValue * negItemFactorValue;
                }
            }
            if (this.isConverged(iter) && this.earlyStop) break;
            this.updateLRate(iter);
        }
    }

    /**
     * Builds, for every user row of {@code sparseMatrix}, a hash set of the column indices
     * stored in that row.
     *
     * @param sparseMatrix user-by-item matrix whose rows are indexed {@code 0..numUsers-1}
     * @return an array of length {@code numUsers}; element {@code u} holds user {@code u}'s
     *     observed item indices
     */
    private IntOpenHashSet[] getUserItemsSet(SequentialAccessSparseMatrix sparseMatrix) {
        IntOpenHashSet[] tempUserItemsSet = new IntOpenHashSet[this.numUsers];
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            int[] itemIndices = sparseMatrix.row(userIdx).getIndices();
            IntOpenHashSet itemSet = new IntOpenHashSet(itemIndices.length);
            for (int itemIdx : itemIndices) {
                itemSet.add(itemIdx);
            }
            tempUserItemsSet[userIdx] = itemSet;
        }
        return tempUserItemsSet;
    }
}

