/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import it.unimi.dsi.fastutil.doubles.Double2DoubleOpenHashMap;
import java.util.Date;
import java.util.Iterator;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.Vector;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Ranking recommender that factorizes the training matrix with coordinate-wise
 * (one latent factor at a time) alternating updates of user and item factors.
 *
 * <p>Every observed entry of the training matrix is replaced by a weight (see
 * {@link #weight(double)}), and every item carries a confidence value that is
 * applied to the item as a whole. The setting {@code rec.eals.wrmf.judge}
 * selects how both are derived:
 * <ul>
 *   <li>{@code 0} - popularity-based item confidences, every observed weight is 1;</li>
 *   <li>{@code 1} - every item confidence is 1, observed weight is
 *       {@code 1 + weightCoefficient * rating};</li>
 *   <li>{@code 2} - popularity-based item confidences and rating-dependent weights.</li>
 * </ul>
 * Any other value yields item confidences of 1 and observed weights of 1.
 *
 * <p>NOTE(review): the class name and configuration keys suggest this is the
 * "eALS" (element-wise ALS) method for implicit feedback - confirm against the
 * project documentation.
 */
@ModelData(value={"isRanking", "eals", "userFactors", "itemFactors", "trainMatrix"})
public class EALSRecommender
extends MatrixFactorizationRecommender {
    /** Judge value: popularity-based item confidences, uniform observed weights. */
    private static final int JUDGE_POPULARITY = 0;
    /** Judge value: uniform item confidences, rating-dependent observed weights. */
    private static final int JUDGE_RATING_WEIGHT = 1;
    /** Judge value: popularity-based item confidences and rating-dependent observed weights. */
    private static final int JUDGE_BOTH = 2;

    /** Multiplier applied to a rating in {@link #weight(double)}; key {@code rec.wrmf.weight.coefficient}, default 4. */
    protected float weightCoefficient;
    /** Exponent applied to the relative item popularity; key {@code rec.eals.ratio}, default 0.4. */
    private float ratio;
    /** Scale of the popularity-based confidences; key {@code rec.eals.overall}, default 128. */
    private float overallWeight;
    /** Mode selector read from {@code rec.eals.wrmf.judge}, default 1. */
    private int WRMFJudge;
    /** One confidence value per item, indexed by item index. */
    private double[] confidences;

    /**
     * Reads the algorithm settings from the configuration, allocates the per-item
     * confidence array and initializes confidences and training-matrix weights.
     *
     * @throws LibrecException if the superclass setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.weightCoefficient = this.conf.getFloat("rec.wrmf.weight.coefficient", 4.0f);
        this.ratio = this.conf.getFloat("rec.eals.ratio", 0.4f);
        this.overallWeight = this.conf.getFloat("rec.eals.overall", 128.0f);
        this.WRMFJudge = this.conf.getInt("rec.eals.wrmf.judge", 1);
        this.confidences = new double[this.numItems];
        this.initConfidencesAndWeights();
    }

    /** Returns whether the configured mode derives item confidences from item popularity. */
    private boolean usesPopularityConfidence() {
        return this.WRMFJudge == JUDGE_POPULARITY || this.WRMFJudge == JUDGE_BOTH;
    }

    /** Returns whether the configured mode makes the observed weight depend on the rating. */
    private boolean usesRatingWeight() {
        return this.WRMFJudge == JUDGE_RATING_WEIGHT || this.WRMFJudge == JUDGE_BOTH;
    }

    /**
     * Fills {@link #confidences} and then rewrites the training matrix entries as weights.
     *
     * <p>In the popularity modes the confidence of item {@code i} is
     * {@code overallWeight * p_i^ratio / sum_j(p_j^ratio)}, where {@code p_i} is the
     * number of entries in column {@code i} divided by {@code numRates}. Otherwise
     * every confidence is 1.
     */
    private void initConfidencesAndWeights() {
        if (this.usesPopularityConfidence()) {
            double sumPopularity = 0.0;
            for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                double relativePopularity = (double)this.trainMatrix.column(itemIdx).getNumEntries() / (double)this.numRates;
                double alphaPopularity = Math.pow(relativePopularity, this.ratio);
                this.confidences[itemIdx] = (double)this.overallWeight * alphaPopularity;
                sumPopularity += alphaPopularity;
            }
            // Normalize so the confidences sum to overallWeight.
            for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                this.confidences[itemIdx] /= sumPopularity;
            }
        } else {
            for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                this.confidences[itemIdx] = 1.0;
            }
        }
        this.weightMatrix();
    }

    /**
     * Maps a rating to the weight stored for an observed entry.
     *
     * @param value the rating value
     * @return {@code 1 + weightCoefficient * value} in the rating-dependent modes
     *         (judge 1 or 2), otherwise {@code 1}
     */
    public double weight(double value) {
        if (this.usesRatingWeight()) {
            return 1.0 + (double)this.weightCoefficient * value;
        }
        return 1.0;
    }

    /**
     * Replaces every entry of the training matrix, in place, by the weight of its rating.
     *
     * <p>Weights are precomputed once per value of {@code ratingScale}. An entry whose
     * value is not a key of that table receives the map's default return value.
     */
    private void weightMatrix() {
        Double2DoubleOpenHashMap ratingWeightMap = new Double2DoubleOpenHashMap();
        // assumes ratingScale is an Iterable<Double> declared by a superclass - TODO confirm
        for (double rating : ratingScale) {
            ratingWeightMap.putIfAbsent(rating, this.weight(rating));
        }
        for (MatrixEntry matrixEntry : this.trainMatrix) {
            matrixEntry.set(ratingWeightMap.get(matrixEntry.get()));
        }
    }

    /**
     * Trains the model: resets the user factors to a fresh matrix and then, for
     * {@code numIterations} rounds, updates all user factors followed by all item factors.
     *
     * @throws LibrecException declared by the overridden method
     */
    @Override
    protected void trainModel() throws LibrecException {
        // Scratch buffers for cached predictions; only the positions belonging to
        // the row/column currently being processed are read.
        double[] usersPredictions = new double[this.numUsers];
        double[] itemsPredictions = new double[this.numItems];
        DenseMatrix itemFactorsCache = new DenseMatrix(this.numFactors, this.numFactors);
        this.userFactors = new DenseMatrix(this.numUsers, this.numFactors);
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.refreshItemFactorsCache(itemFactorsCache);
            this.updateUserFactors(itemFactorsCache, itemsPredictions);
            DenseMatrix userFactorsCache = this.userFactors.transpose().times(this.userFactors);
            this.updateItemFactors(userFactorsCache, usersPredictions);
            if (verbose) {
                this.LOG.info(this.getClass() + " runs at iteration = " + iter + " " + new Date());
            }
        }
    }

    /**
     * Stores in {@code cache} the symmetric matrix whose {@code (f1, f2)} entry is the
     * sum over all items of {@code confidence * itemFactor[f1] * itemFactor[f2]}.
     */
    private void refreshItemFactorsCache(DenseMatrix cache) {
        for (int factorIdx1 = 0; factorIdx1 < this.numFactors; ++factorIdx1) {
            // Only the lower triangle is computed; the value is mirrored.
            for (int factorIdx2 = 0; factorIdx2 <= factorIdx1; ++factorIdx2) {
                double value = 0.0;
                for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                    value += this.confidences[itemIdx] * this.itemFactors.get(itemIdx, factorIdx1) * this.itemFactors.get(itemIdx, factorIdx2);
                }
                cache.set(factorIdx1, factorIdx2, value);
                cache.set(factorIdx2, factorIdx1, value);
            }
        }
    }

    /**
     * Updates every user's latent factors one factor at a time, holding the item
     * factors fixed.
     *
     * @param itemFactorsCache confidence-weighted item Gram matrix from
     *                         {@link #refreshItemFactorsCache(DenseMatrix)}
     * @param itemsPredictions scratch buffer of length {@code numItems}
     */
    private void updateUserFactors(DenseMatrix itemFactorsCache, double[] itemsPredictions) {
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            SequentialSparseVector itemVector = this.trainMatrix.row(userIdx);
            // Cache the current prediction for each item this user has an entry for.
            for (Vector.VectorEntry vectorEntry : itemVector) {
                int itemIdx = vectorEntry.index();
                itemsPredictions[itemIdx] = this.userFactors.row(userIdx).dot(this.itemFactors.row(itemIdx));
            }
            for (int factorCacheIdx = 0; factorCacheIdx < this.numFactors; ++factorCacheIdx) {
                double numer = 0.0;
                double denom = (double)this.regUser + itemFactorsCache.get(factorCacheIdx, factorCacheIdx);
                // Contribution of all other factors over the full item set, via the cache.
                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    if (factorCacheIdx == factorIdx) {
                        continue;
                    }
                    numer -= this.userFactors.get(userIdx, factorIdx) * itemFactorsCache.get(factorCacheIdx, factorIdx);
                }
                double oldUserFactor = this.userFactors.get(userIdx, factorCacheIdx);
                for (Vector.VectorEntry vectorEntry : itemVector) {
                    int itemIdx = vectorEntry.index();
                    // Entry values were turned into weights by weightMatrix().
                    double weight = vectorEntry.get();
                    double itemFactor = this.itemFactors.get(itemIdx, factorCacheIdx);
                    // Remove the current factor's share so the cached value is the
                    // prediction without this factor.
                    itemsPredictions[itemIdx] -= oldUserFactor * itemFactor;
                    numer += (weight - (weight - this.confidences[itemIdx]) * itemsPredictions[itemIdx]) * itemFactor;
                    denom += (weight - this.confidences[itemIdx]) * itemFactor * itemFactor;
                }
                this.userFactors.set(userIdx, factorCacheIdx, numer / denom);
                // Put the (updated) factor's share back into the cached predictions.
                double newUserFactor = this.userFactors.get(userIdx, factorCacheIdx);
                for (Vector.VectorEntry vectorEntry : itemVector) {
                    int itemIdx = vectorEntry.index();
                    itemsPredictions[itemIdx] += newUserFactor * this.itemFactors.get(itemIdx, factorCacheIdx);
                }
            }
        }
    }

    /**
     * Updates every item's latent factors one factor at a time, holding the user
     * factors fixed.
     *
     * @param userFactorsCache {@code userFactors^T * userFactors}
     * @param usersPredictions scratch buffer of length {@code numUsers}
     */
    private void updateItemFactors(DenseMatrix userFactorsCache, double[] usersPredictions) {
        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            SequentialSparseVector userVector = this.trainMatrix.viewColumn(itemIdx);
            // Cache the current prediction for each user with an entry for this item.
            for (Vector.VectorEntry vectorEntry : userVector) {
                int userIdx = vectorEntry.index();
                usersPredictions[userIdx] = this.userFactors.row(userIdx).dot(this.itemFactors.row(itemIdx));
            }
            double confidence = this.confidences[itemIdx];
            for (int factorCacheIdx = 0; factorCacheIdx < this.numFactors; ++factorCacheIdx) {
                double numer = 0.0;
                double denom = confidence * userFactorsCache.get(factorCacheIdx, factorCacheIdx) + (double)this.regItem;
                // Contribution of all other factors over the full user set, via the cache.
                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    if (factorCacheIdx == factorIdx) {
                        continue;
                    }
                    numer -= this.itemFactors.get(itemIdx, factorIdx) * userFactorsCache.get(factorIdx, factorCacheIdx);
                }
                // The user Gram matrix is unweighted, so the item confidence is applied here.
                numer *= confidence;
                double oldItemFactor = this.itemFactors.get(itemIdx, factorCacheIdx);
                for (Vector.VectorEntry vectorEntry : userVector) {
                    int userIdx = vectorEntry.index();
                    // Entry values were turned into weights by weightMatrix().
                    double weight = vectorEntry.get();
                    double userFactor = this.userFactors.get(userIdx, factorCacheIdx);
                    // Remove the current factor's share so the cached value is the
                    // prediction without this factor.
                    usersPredictions[userIdx] -= userFactor * oldItemFactor;
                    numer += (weight - (weight - confidence) * usersPredictions[userIdx]) * userFactor;
                    denom += (weight - confidence) * userFactor * userFactor;
                }
                this.itemFactors.set(itemIdx, factorCacheIdx, numer / denom);
                // Put the (updated) factor's share back into the cached predictions.
                double newItemFactor = this.itemFactors.get(itemIdx, factorCacheIdx);
                for (Vector.VectorEntry vectorEntry : userVector) {
                    int userIdx = vectorEntry.index();
                    usersPredictions[userIdx] += this.userFactors.get(userIdx, factorCacheIdx) * newItemFactor;
                }
            }
        }
    }
}

