/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import it.unimi.dsi.fastutil.doubles.Double2DoubleOpenHashMap;
import java.util.ArrayList;
import java.util.Date;
import java.util.Iterator;
import java.util.stream.IntStream;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.Vector;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Weighted Regularized Matrix Factorization (WRMF) for implicit-feedback ranking, trained with
 * alternating least squares (ALS).
 *
 * <p>During {@link #setup()} every stored rating {@code r} in {@code trainMatrix} is replaced, in
 * place, by {@code weight(r) = ln(1 + 10^weightCoefficient * r)}. Training then treats every stored
 * entry as an observed preference ({@code p = 1}) with confidence {@code c = 1 + weight(r)}, and
 * unstored entries as {@code p = 0, c = 1}. Each user row {@code x_u} (and symmetrically each item
 * row) is re-solved as
 *
 * <pre>
 *   x_u = (Y^T Y + sum_i (c_ui - 1) y_i y_i^T + regUser * I)^{-1} * sum_i c_ui y_i
 * </pre>
 *
 * where the sums run over the items stored in user {@code u}'s row.
 */
@ModelData(value={"isRanking", "wrmf", "userFactors", "itemFactors", "trainMatrix"})
public class WRMFRecommender extends MatrixFactorizationRecommender {

    /** Configuration key read in {@link #setup()} for {@link #weightCoefficient}. */
    private static final String WEIGHT_COEFFICIENT_KEY = "rec.wrmf.weight.coefficient";

    /** Fallback passed to {@code conf.getFloat} for {@link #WEIGHT_COEFFICIENT_KEY}. */
    private static final float DEFAULT_WEIGHT_COEFFICIENT = 4.0f;

    /** Exponent {@code k} in the confidence transform {@code ln(1 + 10^k * r)}. */
    protected float weightCoefficient;

    /**
     * Runs the superclass setup, reads the weight coefficient and converts {@code trainMatrix}
     * from raw ratings to confidence weights.
     *
     * @throws LibrecException if the superclass setup fails
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        this.weightCoefficient = this.conf.getFloat(WEIGHT_COEFFICIENT_KEY, DEFAULT_WEIGHT_COEFFICIENT);
        this.weightMatrix();
    }

    /**
     * Maps a raw rating to its confidence excess {@code ln(1 + 10^weightCoefficient * value)}.
     * Training uses {@code 1 + weight(r)} as the confidence of an observed entry.
     *
     * @param value raw rating value
     * @return the transformed weight
     */
    public double weight(double value) {
        return Math.log(1.0 + Math.pow(10.0, this.weightCoefficient) * value);
    }

    /**
     * Replaces every stored entry of {@code trainMatrix}, in place, with {@link #weight(double)} of
     * its current value. Results are cached per distinct rating value, so the transform is
     * evaluated once per distinct rating. Every stored value gets its true weight, including
     * values that are not listed in {@code ratingScale}.
     *
     * <p>NOTE(review): this mutates {@code trainMatrix}; calling it twice would re-weight the
     * weights. Confirm nothing downstream expects the raw ratings.
     */
    public void weightMatrix() {
        Double2DoubleOpenHashMap weightCache = new Double2DoubleOpenHashMap();
        for (MatrixEntry entry : this.trainMatrix) {
            double rating = entry.get();
            double ratingWeight;
            // Explicit containsKey: the map's default return value (0.0) must never be mistaken
            // for a real weight.
            if (weightCache.containsKey(rating)) {
                ratingWeight = weightCache.get(rating);
            } else {
                ratingWeight = this.weight(rating);
                weightCache.put(rating, ratingWeight);
            }
            entry.set(ratingWeight);
        }
    }

    /**
     * Runs {@code numIterations} rounds of alternating least squares.
     *
     * <p>In each round every user factor row is first re-solved against the current item factors.
     * Then every item factor row is re-solved against the freshly updated user factors. Rows within
     * one half-step are solved in parallel.
     *
     * @throws LibrecException declared by the superclass contract
     */
    @Override
    protected void trainModel() throws LibrecException {
        DenseMatrix X = this.userFactors;
        DenseMatrix Y = this.itemFactors;
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            // The Gram matrix of the fixed side is shared by every row solve in the half-step.
            DenseMatrix YtY = Y.transpose().times(Y);
            // Each task writes only its own row of X and reads only Y. This assumes concurrent
            // DenseMatrix.set calls on distinct rows are safe, as the parallel code always did.
            IntStream.range(0, this.numUsers).parallel().forEach(userIndex ->
                    X.set(userIndex, this.solveFactorRow(this.trainMatrix.row(userIndex), Y, YtY, this.regUser)));

            DenseMatrix XtX = X.transpose().times(X);
            IntStream.range(0, this.numItems).parallel().forEach(itemIndex ->
                    Y.set(itemIndex, this.solveFactorRow(this.trainMatrix.viewColumn(itemIndex), X, XtX, this.regItem)));

            if (verbose) {
                this.LOG.info(this.getClass() + " runs at iteration = " + iter + " " + new Date());
            }
        }
    }

    /**
     * Solves the weighted ridge-regression normal equations for a single factor row.
     *
     * <p>Let {@code w_j} be the stored (already weighted) entries of one user's row or one item's
     * column, and {@code f_j} the matching rows of {@code fixedFactors}. The method returns
     * {@code (gram + sum_j w_j f_j f_j^T + regularization * I)^{-1} * sum_j (1 + w_j) f_j}.
     *
     * @param weights        stored entries; each entry's index selects a row of {@code fixedFactors}
     * @param fixedFactors   factor matrix held constant during this half-step
     * @param gram           precomputed {@code fixedFactors^T * fixedFactors}
     * @param regularization ridge penalty, added to the diagonal only
     * @return the updated factor row
     */
    private DenseVector solveFactorRow(SequentialSparseVector weights, DenseMatrix fixedFactors,
                                       DenseMatrix gram, double regularization) {
        int k = this.numFactors;
        DenseMatrix normalMatrix = new DenseMatrix(k, k);
        // lambda * I: the ridge penalty belongs on the diagonal only, not on every cell.
        normalMatrix.assign((row, col, value) ->
                row == col ? gram.get(row, col) + regularization : gram.get(row, col));
        VectorBasedDenseVector rhs = new VectorBasedDenseVector(k);
        for (Vector.VectorEntry entry : weights) {
            double excess = entry.get();        // c - 1
            double confidence = excess + 1.0;   // c
            DenseVector factors = fixedFactors.row(entry.index());
            for (int r = 0; r < k; ++r) {
                double fr = factors.get(r);
                rhs.plus(r, fr * confidence);
                // Rank-one update (c - 1) * f f^T, so gram need not be rebuilt per row.
                double scaled = fr * excess;
                for (int c = 0; c < k; ++c) {
                    normalMatrix.plus(r, c, scaled * factors.get(c));
                }
            }
        }
        return normalMatrix.inverse().times(rhs);
    }
}

