/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import com.google.common.cache.LoadingCache;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Pairwise ranking recommender that scores a positive item with a blend of
 * the individual user's preference and the averaged preference of a small
 * group of users who all interacted with that item (the class name suggests
 * the GBPR model of Pan and Chen, IJCAI 2013).
 *
 * <p>Configuration keys read in {@link #setup()}:
 * <ul>
 *   <li>{@code rec.gpbr.rho} - blend weight between group and individual score (default 1.5)</li>
 *   <li>{@code rec.gpbr.gsize} - target group size (default 2)</li>
 *   <li>{@code rec.bias.regularization} - item bias regularization (default 0.01)</li>
 *   <li>{@code guava.cache.spec} - spec for the row/column index caches</li>
 * </ul>
 * The {@code gpbr} spelling in the keys is kept as-is so existing
 * configuration files continue to work.
 */
@ModelData(value={"isRanking", "gbpr", "userFactors", "itemFactors", "trainMatrix"})
public class GBPRRecommender
extends MatrixFactorizationRecommender {
    /** Weight of the group score; {@code 1 - rho} weights the individual score. */
    private float rho;
    /** Target number of users in a sampled group. */
    private int gLen;
    /** Regularization coefficient applied to item biases. */
    protected double regBias;
    /** One bias term per item, added to every predicted score. */
    private VectorBasedDenseVector itemBiases;
    /** Cache mapping a user index to the item indices in that user's training row. */
    protected LoadingCache<Integer, List<Integer>> userItemsCache;
    /** Cache mapping an item index to the user indices in that item's training column. */
    protected LoadingCache<Integer, List<Integer>> itemUsersCache;
    /** Guava cache specification string, (re)assigned on every {@link #setup()} call. */
    protected static String cacheSpec;

    /**
     * Reads the model hyper-parameters from the configuration, initializes
     * the item biases and builds the user-items / item-users caches.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.itemBiases = new VectorBasedDenseVector(this.numItems);
        this.itemBiases.init();
        this.rho = this.conf.getFloat("rec.gpbr.rho", Float.valueOf(1.5f)).floatValue();
        this.gLen = this.conf.getInt("rec.gpbr.gsize", 2);
        this.regBias = this.conf.getDouble("rec.bias.regularization", 0.01);
        cacheSpec = this.conf.get("guava.cache.spec", "maximumSize=200,expireAfterAccess=2m");
        this.userItemsCache = this.trainMatrix.rowColumnsCache(cacheSpec);
        this.itemUsersCache = this.trainMatrix.columnRowsCache(cacheSpec);
    }

    /**
     * Trains the model by sampling (user, positive item, group, negative item)
     * tuples. Item biases are updated immediately after each sample; user and
     * item factor updates are accumulated in temporary matrices and applied
     * once at the end of each iteration.
     *
     * @throws LibrecException if a cache lookup or a prediction fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        int maxSample = this.trainMatrix.size();
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            // Factor gradients are accumulated here and applied after all samples.
            DenseMatrix tempUserFactors = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix tempItemFactors = new DenseMatrix(this.numItems, this.numFactors);
            for (int sample = 0; sample < maxSample; ++sample) {
                // Draw a user who has at least one rated item and at least one
                // unrated item; without an unrated item the negative sampling
                // below could never terminate.
                int userIdx;
                List<Integer> ratedItems;
                do {
                    userIdx = Randoms.uniform(this.numUsers);
                    ratedItems = GBPRRecommender.fetchCached(this.userItemsCache, userIdx);
                } while (ratedItems.isEmpty() || ratedItems.size() >= this.numItems);

                int posItemIdx = Randoms.random(ratedItems);
                List<Integer> posRatedUserList = GBPRRecommender.fetchCached(this.itemUsersCache, posItemIdx);
                Set<Integer> groupSet = this.sampleGroup(userIdx, posRatedUserList);

                double posPredictRating = this.predict(userIdx, posItemIdx, groupSet);

                // Rejection-sample an item the user has not rated.
                int negItemIdx;
                do {
                    negItemIdx = Randoms.uniform(this.numItems);
                } while (ratedItems.contains(negItemIdx));
                double negPredictRating = this.predict(userIdx, negItemIdx);

                double diffValue = posPredictRating - negPredictRating;
                double lossValue = -Math.log(Maths.logistic(diffValue));
                this.loss += lossValue;
                double deriValue = Maths.logistic(-diffValue);

                // Item biases are updated in place (not batched).
                double posBiasValue = this.itemBiases.get(posItemIdx);
                this.itemBiases.plus(posItemIdx, (double)this.learnRate * (deriValue - this.regBias * posBiasValue));
                this.loss += this.regBias * posBiasValue * posBiasValue;
                double negBiasValue = this.itemBiases.get(negItemIdx);
                this.itemBiases.plus(negItemIdx, (double)this.learnRate * (-deriValue - this.regBias * negBiasValue));
                this.loss += this.regBias * negBiasValue * negBiasValue;

                double averageWeight = 1.0 / (double)groupSet.size();
                // Per-factor sum of the group members' user factors, used for
                // the positive item's gradient below.
                double[] sumGroup = new double[this.numFactors];
                for (int groupUserIdx : groupSet) {
                    // Only the sampled user receives the individual-preference
                    // and negative-item terms.
                    double delta = groupUserIdx == userIdx ? 1.0 : 0.0;
                    for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                        double groupUserFactorValue = this.userFactors.get(groupUserIdx, factorIdx);
                        double posItemFactorValue = this.itemFactors.get(posItemIdx, factorIdx);
                        double negItemFactorValue = this.itemFactors.get(negItemIdx, factorIdx);
                        double deltaGroup = (double)this.rho * averageWeight * posItemFactorValue + (double)(1.0f - this.rho) * delta * posItemFactorValue - delta * negItemFactorValue;
                        tempUserFactors.plus(groupUserIdx, factorIdx, (double)this.learnRate * (deriValue * deltaGroup - (double)this.regUser * groupUserFactorValue));
                        this.loss += (double)this.regUser * groupUserFactorValue * groupUserFactorValue;
                        sumGroup[factorIdx] += groupUserFactorValue;
                    }
                }

                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    double userFactorValue = this.userFactors.get(userIdx, factorIdx);
                    double posItemFactorValue = this.itemFactors.get(posItemIdx, factorIdx);
                    double negItemFactorValue = this.itemFactors.get(negItemIdx, factorIdx);
                    double posDelta = (double)this.rho * averageWeight * sumGroup[factorIdx] + (double)(1.0f - this.rho) * userFactorValue;
                    tempItemFactors.plus(posItemIdx, factorIdx, (double)this.learnRate * (deriValue * posDelta - (double)this.regItem * posItemFactorValue));
                    this.loss += (double)this.regItem * posItemFactorValue * posItemFactorValue;
                    this.loss += (double)this.regItem * negItemFactorValue * negItemFactorValue;
                    double negDelta = -userFactorValue;
                    tempItemFactors.plus(negItemIdx, factorIdx, (double)this.learnRate * (deriValue * negDelta - (double)this.regItem * negItemFactorValue));
                }
            }
            this.userFactors.assign(this.userFactors.plus(tempUserFactors));
            this.itemFactors.assign(this.itemFactors.plus(tempItemFactors));
            if (this.isConverged(iter) && this.earlyStop) break;
            this.updateLRate(iter);
        }
    }

    /**
     * Loads a cached index list, converting a cache loading failure into a
     * {@link LibrecException} instead of continuing with a {@code null} list.
     *
     * @param cache the cache to query
     * @param key   the row or column index to look up
     * @return the list stored for {@code key}
     * @throws LibrecException if the cache loader fails; the original
     *                         exception is preserved as the cause
     */
    private static List<Integer> fetchCached(LoadingCache<Integer, List<Integer>> cache, int key) throws LibrecException {
        try {
            return cache.get(key);
        }
        catch (ExecutionException e) {
            throw new LibrecException(e);
        }
    }

    /**
     * Builds the user group for one training sample. If the positive item has
     * no more than {@code gLen} users, all of them form the group; otherwise
     * the group starts with {@code userIdx} and is filled with distinct
     * randomly drawn users of the item until it holds {@code gLen} members.
     *
     * @param userIdx          the sampled user
     * @param posRatedUserList users who interacted with the positive item
     * @return the sampled group
     */
    private Set<Integer> sampleGroup(int userIdx, List<Integer> posRatedUserList) {
        Set<Integer> groupSet = new HashSet<Integer>();
        if (posRatedUserList.size() <= this.gLen) {
            groupSet.addAll(posRatedUserList);
            return groupSet;
        }
        groupSet.add(userIdx);
        while (groupSet.size() < this.gLen) {
            // Set.add ignores duplicates, so repeated draws simply retry.
            groupSet.add(Randoms.random(posRatedUserList));
        }
        return groupSet;
    }

    /**
     * Predicts the group-blended score of an item: {@code rho} times the
     * group score (mean of the members' factor dot products plus the item
     * bias) plus {@code 1 - rho} times the individual score.
     *
     * @param userIdx  user index
     * @param itemIdx  item index
     * @param groupSet indices of the users forming the group; must be non-empty
     * @return the blended score
     * @throws LibrecException if the individual prediction fails
     */
    protected double predict(int userIdx, int itemIdx, Set<Integer> groupSet) throws LibrecException {
        double individualScore = this.predict(userIdx, itemIdx);
        double dotSum = 0.0;
        for (int groupUserIdx : groupSet) {
            dotSum += this.userFactors.row(groupUserIdx).dot(this.itemFactors.row(itemIdx));
        }
        double groupScore = dotSum / (double)groupSet.size() + this.itemBiases.get(itemIdx);
        return (double)this.rho * groupScore + (double)(1.0f - this.rho) * individualScore;
    }

    /**
     * Predicts the individual score of an item: the item bias plus the
     * parent class's prediction for the (user, item) pair.
     *
     * @param userIdx user index
     * @param itemIdx item index
     * @return the predicted score
     * @throws LibrecException if the parent prediction fails
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        return this.itemBiases.get(itemIdx) + super.predict(userIdx, itemIdx);
    }
}

