/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.context.rating;

import com.google.common.cache.LoadingCache;
import java.util.List;
import java.util.concurrent.ExecutionException;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.SocialRecommender;

/**
 * Rating recommender whose prediction for a (user, item) pair is the sum of
 * the global mean, a user bias, an item bias and the dot product of the item's
 * latent vector with three user-side terms: the user's own latent vector, the
 * normalised sum of implicit-feedback vectors of the items the user rated, and
 * the normalised sum of trustee vectors of the users this user trusts.
 *
 * <p>Model parameters are learned in {@link #trainModel()} from both the rating
 * matrix and the social (trust) matrix. The per-user item lists and trustee
 * lists are served through Guava {@link LoadingCache}s built in
 * {@link #setup()}; a cache load failure is reported as a
 * {@link LibrecException} rather than being swallowed.
 */
@ModelData(value={"isRating", "trustsvd", "userFactors", "itemFactors", "impItemFactors", "userBiases", "itemBiases", "socialMatrix", "trainMatrix"})
public class TrustSVDRecommender
extends SocialRecommender {
    /** Implicit-feedback latent factors, one row per item. */
    private DenseMatrix impItemFactors;
    /** Latent factors of users in their role as trustee, one row per user. */
    private DenseMatrix trusteeFactors;
    /** Per-user regularisation weight derived from the size of the user's social-matrix column. */
    private VectorBasedDenseVector trusteeWeights;
    /** Per-user regularisation weight derived from the size of the user's social-matrix row. */
    private VectorBasedDenseVector trusterWeights;
    /** Per-item regularisation weight derived from the size of the item's rating-matrix column. */
    private VectorBasedDenseVector impItemWeights;
    /** Per-user bias terms. */
    private VectorBasedDenseVector userBiases;
    /** Per-item bias terms. */
    private VectorBasedDenseVector itemBiases;
    /** Regularisation coefficient for the bias terms (config key {@code rec.bias.regularization}). */
    protected double regBias;
    /** Cache from user index to the column indices of that user's row in the rating matrix. */
    protected LoadingCache<Integer, List<Integer>> userItemsCache;
    /** Cache from user index to the column indices of that user's row in the social matrix. */
    protected LoadingCache<Integer, List<Integer>> userTrusteeCache;
    /** Guava cache specification string (config key {@code guava.cache.spec}); shared by all instances. */
    protected static String cacheSpec;

    /**
     * Reads configuration, allocates and randomly initialises the bias vectors
     * and the trustee / implicit-item factor matrices, pre-computes the
     * {@code 1 / sqrt(count)} regularisation weights (1.0 where the count is
     * zero) and builds the two row-columns caches.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        this.regBias = this.conf.getDouble("rec.bias.regularization", 0.01);
        cacheSpec = this.conf.get("guava.cache.spec", "maximumSize=200,expireAfterAccess=2m");

        this.userBiases = new VectorBasedDenseVector(this.numUsers);
        this.itemBiases = new VectorBasedDenseVector(this.numItems);
        this.userBiases.init(this.initMean, this.initStd);
        this.itemBiases.init(this.initMean, this.initStd);

        this.trusteeFactors = new DenseMatrix(this.numUsers, this.numFactors);
        this.impItemFactors = new DenseMatrix(this.numItems, this.numFactors);
        this.trusteeFactors.init(this.initMean, this.initStd);
        this.impItemFactors.init(this.initMean, this.initStd);

        this.trusteeWeights = new VectorBasedDenseVector(this.numUsers);
        this.trusterWeights = new VectorBasedDenseVector(this.numUsers);
        this.impItemWeights = new VectorBasedDenseVector(this.numItems);
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            // Column size: number of social entries pointing at this user.
            int userFriendCount = this.socialMatrix.column(userIdx).size();
            this.trusteeWeights.set(userIdx, userFriendCount > 0 ? 1.0 / Math.sqrt(userFriendCount) : 1.0);
            // Row size: number of social entries originating from this user.
            userFriendCount = this.socialMatrix.row(userIdx).size();
            this.trusterWeights.set(userIdx, userFriendCount > 0 ? 1.0 / Math.sqrt(userFriendCount) : 1.0);
        }
        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            int itemUsersCount = this.trainMatrix.column(itemIdx).size();
            this.impItemWeights.set(itemIdx, itemUsersCount > 0 ? 1.0 / Math.sqrt(itemUsersCount) : 1.0);
        }

        this.userItemsCache = this.trainMatrix.rowColumnsCache(cacheSpec);
        this.userTrusteeCache = this.socialMatrix.rowColumnsCache(cacheSpec);
    }

    /**
     * Runs up to {@code numIterations} passes over the rating matrix and the
     * social matrix.
     *
     * <p>Within a pass, biases, item factors and implicit-item factors are
     * updated in place per rating entry, whereas the gradients for user
     * factors and trustee factors are accumulated into temporary matrices and
     * applied once at the end of the pass. The accumulated loss is halved and
     * passed to the convergence check.
     *
     * @throws LibrecException if a cache lookup fails or a superclass routine fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            // Gradient accumulators; applied after both matrices have been scanned.
            DenseMatrix tempUserFactors = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix trusteeTempFactors = new DenseMatrix(this.numUsers, this.numFactors);

            for (MatrixEntry matrixEntry : this.trainMatrix) {
                int userIdx = matrixEntry.row();
                int itemIdx = matrixEntry.column();
                double realRating = matrixEntry.get();
                double userBiasValue = this.userBiases.get(userIdx);
                double itemBiasValue = this.itemBiases.get(itemIdx);

                List<Integer> impItemsList = TrustSVDRecommender.lookup(this.userItemsCache, userIdx, "rated items");
                List<Integer> trusteesList = TrustSVDRecommender.lookup(this.userTrusteeCache, userIdx, "trustees");

                double predictRating = this.predictRaw(userIdx, itemIdx, impItemsList, trusteesList);
                double error = predictRating - realRating;
                this.loss += error * error;

                double userWeightDenom = Math.sqrt(impItemsList.size());
                double trusteeWeightDenom = Math.sqrt(trusteesList.size());
                // Guard against an empty item list so the weight never becomes
                // infinite; 1.0 mirrors the fallback used for the weights in setup().
                double userWeight = userWeightDenom > 0.0 ? 1.0 / userWeightDenom : 1.0;
                double itemWeight = this.impItemWeights.get(itemIdx);

                // Bias updates (in place).
                double sgd = error + this.regBias * userWeight * userBiasValue;
                this.userBiases.plus(userIdx, (double)(-this.learnRate) * sgd);
                sgd = error + this.regBias * itemWeight * itemBiasValue;
                this.itemBiases.plus(itemIdx, (double)(-this.learnRate) * sgd);
                this.loss += this.regBias * userWeight * userBiasValue * userBiasValue + this.regBias * itemWeight * itemBiasValue * itemBiasValue;

                // Snapshot the normalised implicit-item and trustee sums before
                // any factor of this entry is modified below.
                double[] sumImpItemsFactors = new double[this.numFactors];
                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    double sum = 0.0;
                    for (int impItemIdx : impItemsList) {
                        sum += this.impItemFactors.get(impItemIdx, factorIdx);
                    }
                    sumImpItemsFactors[factorIdx] = userWeightDenom > 0.0 ? sum / userWeightDenom : sum;
                }
                double[] sumTrusteesFactors = new double[this.numFactors];
                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    double sum = 0.0;
                    for (int trusteeIdx : trusteesList) {
                        sum += this.trusteeFactors.get(trusteeIdx, factorIdx);
                    }
                    sumTrusteesFactors[factorIdx] = trusteeWeightDenom > 0.0 ? sum / trusteeWeightDenom : sum;
                }

                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    // Both values are read before the in-place item update below.
                    double userFactorValue = this.userFactors.get(userIdx, factorIdx);
                    double itemFactorValue = this.itemFactors.get(itemIdx, factorIdx);
                    double deltaUser = error * itemFactorValue + (double)this.regUser * userWeight * userFactorValue;
                    double deltaItem = error * (userFactorValue + sumImpItemsFactors[factorIdx] + sumTrusteesFactors[factorIdx]) + (double)this.regItem * itemWeight * itemFactorValue;
                    tempUserFactors.plus(userIdx, factorIdx, deltaUser);
                    this.itemFactors.plus(itemIdx, factorIdx, (double)(-this.learnRate) * deltaItem);
                    this.loss += (double)this.regUser * userWeight * userFactorValue * userFactorValue + (double)this.regItem * itemWeight * itemFactorValue * itemFactorValue;

                    // Implicit-item factors: updated in place. The loop body only
                    // runs when the list is non-empty, so the denominator is > 0.
                    for (int impItemIdx : impItemsList) {
                        double impItemFactorValue = this.impItemFactors.get(impItemIdx, factorIdx);
                        double impItemWeightValue = this.impItemWeights.get(impItemIdx);
                        double deltaImpItem = error * itemFactorValue / userWeightDenom + (double)this.regItem * impItemWeightValue * impItemFactorValue;
                        this.impItemFactors.plus(impItemIdx, factorIdx, (double)(-this.learnRate) * deltaImpItem);
                        this.loss += (double)this.regItem * impItemWeightValue * impItemFactorValue * impItemFactorValue;
                    }
                    // Trustee factors: gradient is accumulated, not applied yet.
                    for (int trusteeIdx : trusteesList) {
                        double trusteeFactorValue = this.trusteeFactors.get(trusteeIdx, factorIdx);
                        double trusteeWeightValue = this.trusteeWeights.get(trusteeIdx);
                        double deltaTrustee = error * itemFactorValue / trusteeWeightDenom + (double)this.regUser * trusteeWeightValue * trusteeFactorValue;
                        trusteeTempFactors.plus(trusteeIdx, factorIdx, deltaTrustee);
                        this.loss += (double)this.regUser * trusteeWeightValue * trusteeFactorValue * trusteeFactorValue;
                    }
                }
            }

            // Social part: fit userFactors . trusteeFactors to the stored trust value.
            for (MatrixEntry socialMatrixEntry : this.socialMatrix) {
                int userIdx = socialMatrixEntry.row();
                int trusteeIdx = socialMatrixEntry.column();
                double socialValue = socialMatrixEntry.get();
                if (socialValue == 0.0) {
                    continue;
                }
                double predictSocialValue = this.userFactors.row(userIdx).dot(this.trusteeFactors.row(trusteeIdx));
                double socialError = predictSocialValue - socialValue;
                this.loss += (double)this.regSocial * socialError * socialError;
                double deriValue = (double)this.regSocial * socialError;
                double trusterWeightValue = this.trusterWeights.get(userIdx);
                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    double userFactorValue = this.userFactors.get(userIdx, factorIdx);
                    double trusteeFactorValue = this.trusteeFactors.get(trusteeIdx, factorIdx);
                    tempUserFactors.plus(userIdx, factorIdx, deriValue * trusteeFactorValue + (double)this.regSocial * trusterWeightValue * userFactorValue);
                    trusteeTempFactors.plus(trusteeIdx, factorIdx, deriValue * userFactorValue);
                    this.loss += (double)this.regSocial * trusterWeightValue * userFactorValue * userFactorValue;
                }
            }

            // Apply the accumulated user / trustee gradients in one step.
            this.userFactors = this.userFactors.plus(tempUserFactors.times(-this.learnRate));
            this.trusteeFactors = this.trusteeFactors.plus(trusteeTempFactors.times(-this.learnRate));
            this.loss *= 0.5;
            if (this.isConverged(iter) && this.earlyStop) {
                break;
            }
            this.updateLRate(iter);
        }
    }

    /**
     * Predicts the rating of {@code itemIdx} by {@code userIdx} using the
     * current model parameters and the cached item / trustee lists of the user.
     *
     * @param userIdx user index
     * @param itemIdx item index
     * @return the unclamped predicted rating
     * @throws LibrecException if either cache fails to load the user's list
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        List<Integer> userItemsList = TrustSVDRecommender.lookup(this.userItemsCache, userIdx, "rated items");
        List<Integer> trusteeList = TrustSVDRecommender.lookup(this.userTrusteeCache, userIdx, "trustees");
        return this.predictRaw(userIdx, itemIdx, userItemsList, trusteeList);
    }

    /**
     * Delegates to {@link #predict(int, int)}.
     *
     * <p>NOTE(review): {@code bounded} is currently ignored and the raw
     * prediction is returned unchanged — confirm whether clamping to the
     * rating scale is expected by callers.
     *
     * @param userIdx user index
     * @param itemIdx item index
     * @param bounded unused
     * @return the same value as {@code predict(userIdx, itemIdx)}
     * @throws LibrecException if either cache fails to load the user's list
     */
    @Override
    protected double predict(int userIdx, int itemIdx, boolean bounded) throws LibrecException {
        return this.predict(userIdx, itemIdx);
    }

    /**
     * Evaluates the prediction formula for one (user, item) pair given the
     * user's already-resolved item and trustee lists.
     */
    private double predictRaw(int userIdx, int itemIdx, List<Integer> impItems, List<Integer> trustees) throws LibrecException {
        double predictRating = this.globalMean + this.userBiases.get(userIdx) + this.itemBiases.get(itemIdx) + this.userFactors.row(userIdx).dot(this.itemFactors.row(itemIdx));
        if (impItems.size() > 0) {
            double impSum = 0.0;
            for (int impItemIdx : impItems) {
                impSum += this.impItemFactors.row(impItemIdx).dot(this.itemFactors.row(itemIdx));
            }
            predictRating += impSum / Math.sqrt(impItems.size());
        }
        if (trustees.size() > 0) {
            double trustSum = 0.0;
            for (int trusteeIdx : trustees) {
                trustSum += this.trusteeFactors.row(trusteeIdx).dot(this.itemFactors.row(itemIdx));
            }
            predictRating += trustSum / Math.sqrt(trustees.size());
        }
        return predictRating;
    }

    /**
     * Fetches a user's list from a cache, converting a load failure into a
     * {@link LibrecException} that keeps the original cause.
     */
    private static List<Integer> lookup(LoadingCache<Integer, List<Integer>> cache, int userIdx, String description) throws LibrecException {
        try {
            return cache.get(userIdx);
        }
        catch (ExecutionException e) {
            throw new LibrecException("Failed to load " + description + " for user index " + userIdx, e);
        }
    }
}

