/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.structure.SequentialAccessSparseMatrix;
import net.librec.recommender.MatrixFactorizationRecommender;
import org.apache.commons.lang.ArrayUtils;

/**
 * Ranking recommender that learns user and item latent factors with a CLiMF-style
 * (Collaborative Less-is-More Filtering) objective.
 *
 * <p>For each user u with training items R(u) and scores f_j = predict(u, j), training
 * performs gradient ascent (factors are moved by {@code +learnRate * gradient}) on
 * <pre>
 *   sum_{j in R(u)} [ log g(f_j) + sum_{k in R(u)} log(1 - g(f_k - f_j)) ] - regularization
 * </pre>
 * where g is {@code Maths.logistic}. Gradients for a user are evaluated at the factors as they
 * were before that user's update, and then applied together.
 */
@ModelData(value={"isRanking", "climf", "userFactors", "itemFactors"})
public class CLIMFRecommender
extends MatrixFactorizationRecommender {
    /** For each user index, the item indices present in that user's row of the training matrix. */
    private List<Set<Integer>> userItemsSet;

    @Override
    protected void setup() throws LibrecException {
        // Nothing CLiMF-specific to initialise; delegates entirely to the parent setup.
        super.setup();
    }

    /**
     * Runs {@code numIterations} epochs of per-user gradient ascent.
     *
     * <p>Each epoch resets {@code loss}, updates every user's factor row and the factor rows of
     * that user's training items, then calls {@code isConverged(iter)} (breaking out when
     * {@code earlyStop} is set) and {@code updateLRate(iter)}.
     *
     * @throws LibrecException if raised by the inherited helpers called here
     */
    @Override
    protected void trainModel() throws LibrecException {
        this.userItemsSet = this.getUserItemsSet(this.trainMatrix);
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
                int[] items = CLIMFRecommender.toSortedArray(this.userItemsSet.get(userIdx));
                // The factors stay fixed until the updates below, so each score is evaluated once
                // per user rather than inside every factor / item-pair loop.
                double[] predictions = this.predictAll(userIdx, items);
                double[] coefficients = CLIMFRecommender.gradientCoefficients(predictions);

                // dF/dU_u   = sum_j c_j * V_j - regUser * U_u
                // dF/dV_j   = c_j * U_u       - regItem * V_j
                double[] userGradient = new double[this.numFactors];
                double[][] itemGradients = new double[items.length][this.numFactors];
                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    double userFactorValue = this.userFactors.get(userIdx, factorIdx);
                    double userSgd = -((double) this.regUser) * userFactorValue;
                    for (int pos = 0; pos < items.length; ++pos) {
                        double itemFactorValue = this.itemFactors.get(items[pos], factorIdx);
                        userSgd += coefficients[pos] * itemFactorValue;
                        itemGradients[pos][factorIdx] = coefficients[pos] * userFactorValue
                                - (double) this.regItem * itemFactorValue;
                    }
                    userGradient[factorIdx] = userSgd;
                }

                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    this.userFactors.plus(userIdx, factorIdx, (double) this.learnRate * userGradient[factorIdx]);
                }
                for (int pos = 0; pos < items.length; ++pos) {
                    for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                        this.itemFactors.plus(items[pos], factorIdx,
                                (double) this.learnRate * itemGradients[pos][factorIdx]);
                    }
                }

                // The loss is measured with the factors just updated for this user.
                this.accumulateLoss(userIdx, items, this.predictAll(userIdx, items));
            }
            if (this.isConverged(iter) && this.earlyStop) break;
            this.updateLRate(iter);
        }
    }

    /**
     * Returns {@code predict(userIdx, items[pos])} for every position of {@code items}.
     */
    private double[] predictAll(int userIdx, int[] items) throws LibrecException {
        double[] predictions = new double[items.length];
        for (int pos = 0; pos < items.length; ++pos) {
            predictions[pos] = this.predict(userIdx, items[pos]);
        }
        return predictions;
    }

    /**
     * Computes the per-item scalar shared by the user and item gradients:
     * <pre>
     *   c_j = g(-f_j) + sum_{k != j} [ g(f_k - f_j) - g(f_j - f_k) ]
     * </pre>
     * For the logistic g, g'(x) = g(x) * g(-x) and 1 - g(x) = g(-x). The CLiMF terms
     * g'(d) / (1 - g(d)) and g'(-d) * (1/(1 - g(d)) - 1/(1 - g(-d))) therefore reduce to g(d) and
     * g(d) - g(-d). The reduced forms never divide by 1 - g(d). That denominator rounds to exactly
     * 0.0 in double precision once d is around 37 (with the usual 1/(1+exp(-x)) evaluation) and
     * used to turn the gradients into Infinity/NaN.
     *
     * @param predictions the user's scores f for each training item, by position
     * @return c, aligned with {@code predictions}
     */
    private static double[] gradientCoefficients(double[] predictions) {
        double[] coefficients = new double[predictions.length];
        for (int j = 0; j < predictions.length; ++j) {
            double coefficient = Maths.logistic(-predictions[j]);
            for (int k = 0; k < predictions.length; ++k) {
                if (k == j) continue;
                double diffValue = predictions[k] - predictions[j];
                coefficient += Maths.logistic(diffValue) - Maths.logistic(-diffValue);
            }
            coefficients[j] = coefficient;
        }
        return coefficients;
    }

    /**
     * Adds one user's contribution to {@code loss}.
     *
     * <p>For each training item j, this adds log g(f_j) + sum_k log(1 - g(f_k - f_j)). The sum over
     * k includes k == j, as the original loop did. For every item index in [0, numItems), it adds
     * -0.5 * (regUser * u_f^2 + regItem * v_f^2) over all factors.
     *
     * <p>NOTE(review): because the regularization term is accumulated once per (user, item) pair,
     * it is scaled relative to -lambda/2 * (||U||^2 + ||V||^2). It is kept unchanged because the
     * loss value is presumably compared across epochs by isConverged/updateLRate.
     *
     * @param items       the user's training items, sorted ascending
     * @param predictions scores aligned with {@code items}, computed with the updated factors
     */
    private void accumulateLoss(int userIdx, int[] items, double[] predictions) {
        int next = 0;
        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            if (next < items.length && items[next] == itemIdx) {
                double predictValue = predictions[next];
                this.loss += Math.log(Maths.logistic(predictValue));
                for (double compPredictValue : predictions) {
                    // log(1 - g(f_k - f_j)) written as log(g(f_j - f_k)): equal for the logistic,
                    // but stays finite when g(f_k - f_j) rounds to 1.
                    this.loss += Math.log(Maths.logistic(predictValue - compPredictValue));
                }
                ++next;
            }
            for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                double userFactorValue = this.userFactors.get(userIdx, factorIdx);
                double itemFactorValue = this.itemFactors.get(itemIdx, factorIdx);
                this.loss += -0.5 * ((double) this.regUser * userFactorValue * userFactorValue
                        + (double) this.regItem * itemFactorValue * itemFactorValue);
            }
        }
    }

    /**
     * Copies a set of item indices into an ascending {@code int[]}.
     */
    private static int[] toSortedArray(Set<Integer> itemSet) {
        int[] items = new int[itemSet.size()];
        int pos = 0;
        for (int itemIdx : itemSet) {
            items[pos++] = itemIdx;
        }
        Arrays.sort(items);
        return items;
    }

    /**
     * Collects, for each user index in [0, numUsers), the column indices stored in that user's
     * row of {@code sparseMatrix}.
     */
    private List<Set<Integer>> getUserItemsSet(SequentialAccessSparseMatrix sparseMatrix) {
        ArrayList<Set<Integer>> userItemsSet = new ArrayList<Set<Integer>>(this.numUsers);
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            int[] itemIndexes = sparseMatrix.row(userIdx).getIndices();
            Integer[] inputBoxed = ArrayUtils.toObject(itemIndexes);
            List<Integer> itemList = Arrays.asList(inputBoxed);
            userItemsSet.add(new HashSet<Integer>(itemList));
        }
        return userItemsSet;
    }
}

