/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.rating;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SequentialAccessSparseMatrix;
import net.librec.recommender.cf.rating.BiasedMFRecommender;
import org.apache.commons.lang.ArrayUtils;

public class ASVDPlusPlusRecommender
extends BiasedMFRecommender {
    protected DenseMatrix impItemFactors;
    protected DenseMatrix neiItemFactors;
    protected List<List<Integer>> userItemsList;

    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.impItemFactors = new DenseMatrix(this.numItems, this.numFactors);
        this.impItemFactors.init(this.initMean, this.initStd);
        this.neiItemFactors = new DenseMatrix(this.numItems, this.numFactors);
        this.neiItemFactors.init(this.initMean, this.initStd);
        this.userItemsList = this.getUserItemsList(this.trainMatrix);
    }

    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            for (MatrixEntry matrixEntry : this.trainMatrix) {
                int factorIdx;
                int userIdx = matrixEntry.row();
                int itemIdx = matrixEntry.column();
                double realRating = matrixEntry.get();
                double predictRating = this.predict(userIdx, itemIdx);
                double error = realRating - predictRating;
                List<Integer> items = this.userItemsList.get(userIdx);
                double impNor = Math.sqrt(items.size());
                double userBiasValue = this.userBiases.get(userIdx);
                this.userBiases.plus(userIdx, (double)this.learnRate * (error - this.regBias * userBiasValue));
                double itemBiasValue = this.itemBiases.get(itemIdx);
                this.itemBiases.plus(itemIdx, (double)this.learnRate * (error - this.regBias * itemBiasValue));
                double[] sumImpItemsFactors = new double[this.numFactors];
                double[] sumNeiItemsFactors = new double[this.numFactors];
                for (factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    double sumImpItemsFactor = 0.0;
                    double sumNeiItemsFactor = 0.0;
                    for (int ItemIdx : items) {
                        sumImpItemsFactor += this.impItemFactors.get(ItemIdx, factorIdx);
                        sumNeiItemsFactor += this.neiItemFactors.get(ItemIdx, factorIdx) * (realRating - this.globalMean - this.userBiases.get(userIdx) - this.itemBiases.get(ItemIdx));
                    }
                    sumImpItemsFactors[factorIdx] = impNor > 0.0 ? sumImpItemsFactor / impNor : sumImpItemsFactor;
                    sumNeiItemsFactors[factorIdx] = impNor > 0.0 ? sumNeiItemsFactor / impNor : sumNeiItemsFactor;
                }
                for (factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    double userFactorIdx = this.userFactors.get(userIdx, factorIdx);
                    double itemFactorIdx = this.itemFactors.get(itemIdx, factorIdx);
                    double sgd_user = error * itemFactorIdx - (double)this.regUser * userFactorIdx;
                    double sgd_item = error * (userFactorIdx + sumImpItemsFactors[factorIdx] + sumNeiItemsFactors[factorIdx]) - (double)this.regItem * itemFactorIdx;
                    this.userFactors.plus(userIdx, factorIdx, (double)this.learnRate * sgd_user);
                    this.itemFactors.plus(itemIdx, factorIdx, (double)this.learnRate * sgd_item);
                    for (int ImpitemIdx : items) {
                        double impItemFactorIdx = this.impItemFactors.get(ImpitemIdx, factorIdx);
                        double neiItemFactorIdx = this.neiItemFactors.get(ImpitemIdx, factorIdx);
                        double delta_impItem = error * itemFactorIdx / impNor - (double)this.regUser * impItemFactorIdx;
                        double delta_neiItem = error * itemFactorIdx * (realRating - this.globalMean - this.userBiases.get(userIdx) - this.itemBiases.get(ImpitemIdx)) / impNor - (double)this.regUser * neiItemFactorIdx;
                        this.impItemFactors.plus(ImpitemIdx, factorIdx, (double)this.learnRate * delta_impItem);
                        this.neiItemFactors.plus(ImpitemIdx, factorIdx, (double)this.learnRate * delta_neiItem);
                    }
                }
            }
        }
    }

    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        double predictRating = this.globalMean + this.userBiases.get(userIdx) + this.itemBiases.get(itemIdx) + super.predict(userIdx, itemIdx);
        HashMap<Integer, Integer> itemHashMap = new HashMap<Integer, Integer>();
        int[] itemIndices = this.trainMatrix.row(userIdx).getIndices();
        for (int i = 0; i < itemIndices.length; ++i) {
            itemHashMap.put(itemIndices[i], i);
        }
        List<Integer> items = this.userItemsList.get(userIdx);
        double w = Math.sqrt(items.size());
        for (int k : items) {
            predictRating += this.impItemFactors.row(k).dot(this.itemFactors.row(itemIdx)) / w;
            predictRating += this.neiItemFactors.row(k).times(this.trainMatrix.row(userIdx).getAtPosition((Integer)itemHashMap.get(k)) - this.globalMean - this.userBiases.get(userIdx) - this.itemBiases.get(k)).dot(this.itemFactors.row(itemIdx)) / w;
        }
        return predictRating;
    }

    private List<List<Integer>> getUserItemsList(SequentialAccessSparseMatrix sparseMatrix) {
        ArrayList<List<Integer>> userItemsList = new ArrayList<List<Integer>>();
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            int[] itemIndexes = this.trainMatrix.row(userIdx).getIndices();
            Integer[] inputBoxed = ArrayUtils.toObject(itemIndexes);
            List<Integer> itemList = Arrays.asList(inputBoxed);
            userItemsList.add(itemList);
        }
        return userItemsList;
    }
}

