/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.context.rating;

import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.Vector;
import net.librec.recommender.SocialRecommender;

@ModelData(value={"isRating", "rste", "userFactors", "itemFactors", "userSocialRatio", "socialMatrix"})
public class RSTERecommender
extends SocialRecommender {
    private float userSocialRatio;

    @Override
    public void setup() throws LibrecException {
        super.setup();
        this.userFactors.init(1.0);
        this.itemFactors.init(1.0);
        this.userSocialRatio = this.conf.getFloat("rec.user.social.ratio", Float.valueOf(0.8f)).floatValue();
    }

    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            DenseMatrix tempUserFactors = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix tempItemFactors = new DenseMatrix(this.numItems, this.numFactors);
            for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
                SequentialSparseVector userSoicalValues = this.socialMatrix.row(userIdx);
                double weightSocialSum = 0.0;
                for (Vector.VectorEntry ve : userSoicalValues) {
                    double socialValue = ve.get();
                    weightSocialSum += socialValue;
                }
                double[] sumUserSocialFactor = new double[this.numFactors];
                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    for (Vector.VectorEntry ve : userSoicalValues) {
                        int userSocialIdx = ve.index();
                        double socialValue = ve.get();
                        int n = factorIdx;
                        sumUserSocialFactor[n] = sumUserSocialFactor[n] + socialValue * this.userFactors.get(userSocialIdx, factorIdx);
                    }
                }
                for (Vector.VectorEntry vectorEntry : this.trainMatrix.row(userIdx)) {
                    int itemIdx = vectorEntry.index();
                    double rating = vectorEntry.get();
                    double norRating = Maths.normalize(rating, this.minRate, this.maxRate);
                    double predictRating = this.userFactors.row(userIdx).dot(this.itemFactors.row(itemIdx));
                    double sum2 = 0.0;
                    for (Vector.VectorEntry ve : userSoicalValues) {
                        int userSocialIdx = ve.index();
                        double socialValue = ve.get();
                        sum2 += socialValue * this.userFactors.row(userSocialIdx).dot(this.itemFactors.row(itemIdx));
                    }
                    double socialPredictRating = weightSocialSum > 0.0 ? sum2 / weightSocialSum : 0.0;
                    double finalPredictRating = (double)this.userSocialRatio * predictRating + (double)(1.0f - this.userSocialRatio) * socialPredictRating;
                    double error = Maths.logistic(finalPredictRating) - norRating;
                    this.loss += error * error;
                    double deriValue = Maths.logisticGradientValue(finalPredictRating) * error;
                    for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                        double userFactorValue = this.userFactors.get(userIdx, factorIdx);
                        double itemFactorValue = this.itemFactors.get(itemIdx, factorIdx);
                        double userDeriValue = (double)this.userSocialRatio * deriValue * itemFactorValue + (double)this.regUser * userFactorValue;
                        double userSocialFactorValue = weightSocialSum > 0.0 ? sumUserSocialFactor[factorIdx] / weightSocialSum : 0.0;
                        double itemDeriValue = deriValue * ((double)this.userSocialRatio * userFactorValue + (double)(1.0f - this.userSocialRatio) * userSocialFactorValue) + (double)this.regItem * itemFactorValue;
                        tempUserFactors.plus(userIdx, factorIdx, userDeriValue);
                        tempItemFactors.plus(itemIdx, factorIdx, itemDeriValue);
                        this.loss += (double)this.regUser * userFactorValue * userFactorValue + (double)this.regItem * itemFactorValue * itemFactorValue;
                    }
                }
            }
            for (int userSocialIdx = 0; userSocialIdx < this.numUsers; ++userSocialIdx) {
                SequentialSparseVector socialUserValues = this.socialMatrix.column(userSocialIdx);
                for (Vector.VectorEntry ve_1 : socialUserValues) {
                    int socialUserIdx = ve_1.index();
                    double socialUserValue = ve_1.get();
                    SequentialSparseVector socialItemValues = this.trainMatrix.row(socialUserIdx);
                    SequentialSparseVector socialUserSoicalValues = this.socialMatrix.row(socialUserIdx);
                    int[] socialUserSocialIndices = socialUserSoicalValues.getIndices();
                    for (Vector.VectorEntry ve_2 : socialItemValues) {
                        int socialItemIdx = ve_2.index();
                        double socialItemValue = ve_2.get();
                        double predictRating = this.userFactors.row(socialUserIdx).dot(this.itemFactors.row(socialItemIdx));
                        double sum3 = 0.0;
                        double socialWeightSum = 0.0;
                        for (Vector.VectorEntry ve_3 : socialUserSoicalValues) {
                            int socialUserSocialIdx = ve_3.index();
                            double socialUserSocialValue = ve_3.get();
                            sum3 += socialUserSocialValue * this.userFactors.row(socialUserSocialIdx).dot(this.itemFactors.row(socialItemIdx));
                            socialWeightSum += socialUserSocialValue;
                        }
                        double socialPredictRating = socialWeightSum > 0.0 ? sum3 / socialWeightSum : 0.0;
                        double finalPredictRating = (double)this.userSocialRatio * predictRating + (double)(1.0f - this.userSocialRatio) * socialPredictRating;
                        double error = Maths.logistic(finalPredictRating) - Maths.normalize(socialItemValue, this.minRate, this.maxRate);
                        double deriValue = Maths.logisticGradientValue(finalPredictRating) * error * socialUserValue;
                        for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                            tempUserFactors.plus(userSocialIdx, factorIdx, (double)(1.0f - this.userSocialRatio) * deriValue * this.itemFactors.get(socialItemIdx, factorIdx));
                        }
                    }
                }
            }
            this.userFactors = this.userFactors.plus(tempUserFactors.times(-this.learnRate));
            this.itemFactors = this.itemFactors.plus(tempItemFactors.times(-this.learnRate));
            this.loss *= 0.5;
            if (this.isConverged(iter) && this.earlyStop) break;
            this.updateLRate(iter);
        }
    }

    @Override
    protected double predict(int userIdx, int itemIdx) {
        double predictRating = this.userFactors.row(userIdx).dot(this.itemFactors.row(itemIdx));
        double sum2 = 0.0;
        double socialWeightSum = 0.0;
        SequentialSparseVector userSocialVector = this.socialMatrix.row(userIdx);
        for (Vector.VectorEntry ve : userSocialVector) {
            int userSoicalIdx = ve.index();
            double userSocialValue = ve.get();
            sum2 += userSocialValue * this.userFactors.row(userSoicalIdx).dot(this.itemFactors.row(itemIdx));
            socialWeightSum += userSocialValue;
        }
        double soicalPredictRatting = socialWeightSum > 0.0 ? sum2 / socialWeightSum : 0.0;
        predictRating = (double)this.userSocialRatio * predictRating + (double)(1.0f - this.userSocialRatio) * soicalPredictRatting;
        return predictRating;
    }
}

