/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.MatrixFactorizationRecommender;
import net.librec.util.Lists;

@ModelData(value={"isRanking", "wbpr", "userFactors", "itemFactors", "itemBiases", "trainMatrix"})
/**
 * Matrix-factorization ranking recommender with per-item biases, trained by stochastic
 * gradient steps on the pairwise loss {@code -log(logistic(x_ui - x_uj))}.
 *
 * <p>For each sample a user {@code u} with at least one training item is drawn, a positive
 * item {@code i} is drawn uniformly from the user's training items, and a negative item
 * {@code j} is drawn from the items the user has NOT interacted with. The probability of
 * drawing {@code j} is proportional to its popularity (its number of entries in the
 * training matrix).
 */
public class WBPRRecommender
extends MatrixFactorizationRecommender {
    /** Per-user cache of the set of item indices present in the user's training row. */
    private LoadingCache<Integer, Set<Integer>> userItemsSet;
    /**
     * (item index, popularity) pairs for every item, ordered by {@code Lists.sortList(..., true)}.
     * Popularity is the item's number of training entries. Entries are treated as read-only
     * after {@link #setup()}.
     */
    private List<Map.Entry<Integer, Double>> sortedItemPops;
    /** Per-user cache of normalized negative-sampling probabilities over non-rated items. */
    private LoadingCache<Integer, List<Map.Entry<Integer, Double>>> cacheItemProbs;
    /** Per-user cache of the item indices present in the user's training row. */
    protected LoadingCache<Integer, List<Integer>> userItemsCache;
    /** Learned per-item bias added on top of the factor-model prediction. */
    private VectorBasedDenseVector itemBiases;
    /** Regularization coefficient applied to the item biases. */
    protected float regBias;
    /** Guava cache spec for all per-user caches; static, so it is shared by every instance. */
    protected static String cacheSpec;

    /**
     * Reads hyper-parameters and initializes the item biases, the per-user caches and the
     * item-popularity table used for negative sampling.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.regBias = this.conf.getFloat("rec.bias.regularization", Float.valueOf(0.01f)).floatValue();
        this.itemBiases = new VectorBasedDenseVector(this.numItems);
        this.itemBiases.init(0.01);
        cacheSpec = this.conf.get("guava.cache.spec", "maximumSize=200,expireAfterAccess=2m");
        this.userItemsCache = this.trainMatrix.rowColumnsCache(cacheSpec);
        this.userItemsSet = this.trainMatrix.rowColumnsSetCache(cacheSpec);
        // Popularity of an item = number of training entries in its column.
        this.sortedItemPops = new ArrayList<Map.Entry<Integer, Double>>();
        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            double popularity = this.trainMatrix.column(itemIdx).getNumEntries();
            this.sortedItemPops.add(new AbstractMap.SimpleEntry<Integer, Double>(itemIdx, popularity));
        }
        // NOTE(review): presumably sorts by value, descending when the flag is true — confirm in Lists.
        Lists.sortList(this.sortedItemPops, true);
        this.cacheItemProbs = this.getCacheItemProbs();
    }

    /**
     * Runs {@code numIterations} epochs of {@code trainMatrix.size()} sampled
     * (user, positive, negative) updates each. Training stops early when converged and
     * {@code earlyStop} is set.
     *
     * @throws LibrecException if prediction fails
     * @throws IllegalStateException if a per-user cache cannot be loaded
     */
    @Override
    protected void trainModel() throws LibrecException {
        int maxSample = this.trainMatrix.size();
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            for (int sampleCount = 0; sampleCount < maxSample; ++sampleCount) {
                // Draw a user that has at least one training item.
                int userIdx;
                List<Integer> ratedItems;
                do {
                    userIdx = Randoms.uniform(this.numUsers);
                    ratedItems = getFromCache(this.userItemsCache, userIdx);
                } while (ratedItems.isEmpty());
                int posItemIdx = Randoms.random(ratedItems);

                List<Map.Entry<Integer, Double>> itemProbs = getFromCache(this.cacheItemProbs, userIdx);
                if (itemProbs.isEmpty()) {
                    // No unrated item with positive popularity exists for this user, so no valid
                    // negative can be drawn; skip rather than silently using item 0.
                    continue;
                }
                int negItemIdx = sampleNegativeItem(itemProbs, Randoms.random());
                this.updateModel(userIdx, posItemIdx, negItemIdx);
            }
            if (this.isConverged(iter) && this.earlyStop) break;
            this.updateLRate(iter);
        }
    }

    /** Reads {@code key} from {@code cache}, rethrowing loader failures with the cause preserved. */
    private static <K, V> V getFromCache(LoadingCache<K, V> cache, K key) {
        try {
            return cache.get(key);
        }
        catch (ExecutionException e) {
            throw new IllegalStateException("Failed to load cache entry for key " + key, e);
        }
    }

    /**
     * Draws an item by inverse-CDF sampling over {@code itemProbs}.
     *
     * @param itemProbs non-empty (item, probability) pairs whose probabilities sum to about 1
     * @param rand uniform draw; presumably in [0, 1) — confirm against {@code Randoms.random()}
     * @return the first item whose cumulative probability reaches {@code rand}, or the last
     *         item when floating-point rounding leaves the total just below {@code rand}
     */
    private static int sampleNegativeItem(List<Map.Entry<Integer, Double>> itemProbs, double rand) {
        double cumulative = 0.0;
        for (Map.Entry<Integer, Double> itemProb : itemProbs) {
            cumulative += itemProb.getValue();
            if (cumulative >= rand) {
                return itemProb.getKey();
            }
        }
        return itemProbs.get(itemProbs.size() - 1).getKey();
    }

    /**
     * Applies one SGD step for the triple (user, positive item, negative item). Updates the
     * two item biases and all latent factors, and accumulates the pairwise loss plus the
     * regularization terms into {@code loss}.
     */
    private void updateModel(int userIdx, int posItemIdx, int negItemIdx) throws LibrecException {
        double posPredictRating = this.predict(userIdx, posItemIdx);
        double negPredictRating = this.predict(userIdx, negItemIdx);
        double diffValue = posPredictRating - negPredictRating;
        this.loss += -Math.log(Maths.logistic(diffValue));
        // Gradient scale of -log(logistic(diff)) with respect to diff (sign folded into updates).
        double deriValue = Maths.logistic(-diffValue);

        double posItemBiasValue = this.itemBiases.get(posItemIdx);
        double negItemBiasValue = this.itemBiases.get(negItemIdx);
        this.itemBiases.plus(posItemIdx, (double)this.learnRate * (deriValue - (double)this.regBias * posItemBiasValue));
        this.itemBiases.plus(negItemIdx, (double)this.learnRate * (-deriValue - (double)this.regBias * negItemBiasValue));
        this.loss += (double)this.regBias * (posItemBiasValue * posItemBiasValue + negItemBiasValue * negItemBiasValue);

        for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
            // Read all three values before writing so every update uses the pre-step factors.
            double userFactorValue = this.userFactors.get(userIdx, factorIdx);
            double posItemFactorValue = this.itemFactors.get(posItemIdx, factorIdx);
            double negItemFactorValue = this.itemFactors.get(negItemIdx, factorIdx);
            this.userFactors.plus(userIdx, factorIdx, (double)this.learnRate * (deriValue * (posItemFactorValue - negItemFactorValue) - (double)this.regUser * userFactorValue));
            this.itemFactors.plus(posItemIdx, factorIdx, (double)this.learnRate * (deriValue * userFactorValue - (double)this.regItem * posItemFactorValue));
            this.itemFactors.plus(negItemIdx, factorIdx, (double)this.learnRate * (deriValue * -userFactorValue - (double)this.regItem * negItemFactorValue));
            this.loss += (double)this.regUser * userFactorValue * userFactorValue + (double)this.regItem * posItemFactorValue * posItemFactorValue + (double)this.regItem * negItemFactorValue * negItemFactorValue;
        }
    }

    /**
     * Predicts the score of {@code itemIdx} for {@code userIdx}: the item bias plus the
     * parent factor-model prediction.
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        return this.itemBiases.get(itemIdx) + super.predict(userIdx, itemIdx);
    }

    /**
     * Builds the per-user cache of negative-sampling distributions. For user {@code u} the
     * list contains every item that is not in {@code u}'s training set and has popularity
     * {@code > 0}. It keeps the order of {@code sortedItemPops}, and each value is
     * popularity / (sum of the listed popularities).
     */
    private LoadingCache<Integer, List<Map.Entry<Integer, Double>>> getCacheItemProbs() {
        return CacheBuilder.from(cacheSpec).build(new CacheLoader<Integer, List<Map.Entry<Integer, Double>>>(){

            @Override
            public List<Map.Entry<Integer, Double>> load(Integer u) throws Exception {
                Set<Integer> ratedItemsSet = WBPRRecommender.this.userItemsSet.get(u);
                List<Map.Entry<Integer, Double>> itemProbs = new ArrayList<Map.Entry<Integer, Double>>();
                double popularitySum = 0.0;
                for (Map.Entry<Integer, Double> entry : WBPRRecommender.this.sortedItemPops) {
                    Integer itemIdx = entry.getKey();
                    double popularity = entry.getValue();
                    if (ratedItemsSet.contains(itemIdx) || !(popularity > 0.0)) {
                        continue;
                    }
                    // Copy instead of reusing the shared entry: normalizing it in place would
                    // corrupt sortedItemPops for every user loaded afterwards.
                    itemProbs.add(new AbstractMap.SimpleEntry<Integer, Double>(itemIdx, popularity));
                    popularitySum += popularity;
                }
                for (Map.Entry<Integer, Double> entry : itemProbs) {
                    entry.setValue(entry.getValue() / popularitySum);
                }
                return itemProbs;
            }
        });
    }
}

