/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.rating;

import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Matrix-factorization rating recommender ("nmf") trained with multiplicative updates.
 *
 * <p>Both factor matrices are stored transposed, i.e. as {@code numFactors x numUsers} and
 * {@code numFactors x numItems}: one row holds a single latent factor for every user (or
 * item) and one column holds the complete factor vector of a single user (or item). A
 * rating is predicted as the dot product of the user's column and the item's column.
 */
@ModelData(value={"isRating", "nmf", "transUserFactors", "transItemFactors"})
public class NMFRecommender
extends MatrixFactorizationRecommender {
    /** Added to every update denominator so the multiplicative ratio never divides by exactly zero. */
    private static final double DENOMINATOR_EPSILON = 1.0E-9;

    /** Transposed user factors, {@code numFactors x numUsers}. */
    private DenseMatrix transUserFactors;
    /** Transposed item factors, {@code numFactors x numItems}. */
    private DenseMatrix transItemFactors;
    /** Number of latent factors, read from {@code rec.factor.number}. */
    protected int numFactors;
    /** Maximum number of training passes, read from {@code rec.iterator.maximum}. */
    protected int numIterations;

    /**
     * Reads the factor count (default 10) and the iteration cap (default 100) from the
     * configuration and allocates both transposed factor matrices.
     *
     * @throws LibrecException if the superclass setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.numFactors = this.conf.getInt("rec.factor.number", 10);
        this.numIterations = this.conf.getInt("rec.iterator.maximum", 100);
        this.transUserFactors = new DenseMatrix(this.numFactors, this.numUsers);
        this.transItemFactors = new DenseMatrix(this.numFactors, this.numItems);
        // NOTE(review): init(0.01) presumably fills the matrix with small starting values;
        // confirm the exact semantics against DenseMatrix.init.
        this.transUserFactors.init(0.01);
        this.transItemFactors.init(0.01);
    }

    /**
     * Trains the model with at most {@code numIterations} passes.
     *
     * <p>Each pass updates the user factors with the item factors held fixed, then the
     * item factors with the user factors held fixed, and finally recomputes the loss.
     * Training stops early when {@code isConverged} reports convergence and
     * {@code earlyStop} is set.
     *
     * @throws LibrecException if a prediction or the convergence check fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        // Pass indices run 0 .. numIterations - 1, so the configured cap is the exact
        // upper bound on the number of passes (an inclusive bound would run one extra).
        for (int iter = 0; iter < this.numIterations; ++iter) {
            this.applyUserMultiplicativeUpdate();
            this.applyItemMultiplicativeUpdate();
            this.loss = this.computeSquaredErrorLoss();
            if (this.isConverged(iter) && this.earlyStop) {
                break;
            }
            this.lastLoss = this.loss;
        }
    }

    /**
     * Rescales every user's factors by the ratio (observed / predicted) of the
     * factor-weighted ratings, keeping the item factors fixed. Users without any
     * training entries are skipped.
     */
    private void applyUserMultiplicativeUpdate() throws LibrecException {
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            SequentialSparseVector observedRatings = this.trainMatrix.row(userIdx);
            int observedCount = observedRatings.getNumEntries();
            if (observedCount <= 0) {
                continue;
            }

            // Predictions are taken once, before any of this user's factors change, and
            // only at the positions that carry a training entry.
            VectorBasedDenseVector predictedRatings = new VectorBasedDenseVector(this.numItems);
            for (int pos = 0; pos < observedCount; ++pos) {
                int itemIdx = observedRatings.getIndexAtPosition(pos);
                predictedRatings.set(itemIdx, this.predict(userIdx, itemIdx));
            }

            for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                VectorBasedDenseVector itemsForFactor = (VectorBasedDenseVector)this.transItemFactors.row(factorIdx);
                double observedValue = itemsForFactor.dot(observedRatings);
                double estimatedValue = itemsForFactor.dot(predictedRatings) + DENOMINATOR_EPSILON;
                double currentValue = this.transUserFactors.get(factorIdx, userIdx);
                this.transUserFactors.set(factorIdx, userIdx, currentValue * (observedValue / estimatedValue));
            }
        }
    }

    /**
     * Rescales every item's factors by the ratio (observed / predicted) of the
     * factor-weighted ratings, keeping the user factors fixed. Items without any
     * training entries are skipped.
     */
    private void applyItemMultiplicativeUpdate() throws LibrecException {
        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            SequentialSparseVector observedRatings = this.trainMatrix.column(itemIdx);
            int observedCount = observedRatings.getNumEntries();
            if (observedCount <= 0) {
                continue;
            }

            // Predictions are taken once, before any of this item's factors change, and
            // only at the positions that carry a training entry.
            VectorBasedDenseVector predictedRatings = new VectorBasedDenseVector(this.numUsers);
            for (int pos = 0; pos < observedCount; ++pos) {
                int userIdx = observedRatings.getIndexAtPosition(pos);
                predictedRatings.set(userIdx, this.predict(userIdx, itemIdx));
            }

            for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                VectorBasedDenseVector usersForFactor = (VectorBasedDenseVector)this.transUserFactors.row(factorIdx);
                double observedValue = usersForFactor.dot(observedRatings);
                double estimatedValue = usersForFactor.dot(predictedRatings) + DENOMINATOR_EPSILON;
                double currentValue = this.transItemFactors.get(factorIdx, itemIdx);
                this.transItemFactors.set(factorIdx, itemIdx, currentValue * (observedValue / estimatedValue));
            }
        }
    }

    /**
     * Computes half the sum of squared prediction errors over the training entries
     * whose rating is strictly positive.
     *
     * @return {@code 0.5 * sum((predict(u, i) - rating)^2)}
     */
    private double computeSquaredErrorLoss() throws LibrecException {
        double squaredErrorSum = 0.0;
        for (MatrixEntry entry : this.trainMatrix) {
            int userIdx = entry.row();
            int itemIdx = entry.column();
            double rating = entry.get();
            if (rating > 0.0) {
                double ratingError = this.predict(userIdx, itemIdx) - rating;
                squaredErrorSum += ratingError * ratingError;
            }
        }
        return squaredErrorSum * 0.5;
    }

    /**
     * Predicts a rating as the dot product of the user's and the item's factor columns.
     *
     * @param userIdx column index of the user in the transposed user-factor matrix
     * @param itemIdx column index of the item in the transposed item-factor matrix
     * @return the predicted rating
     * @throws LibrecException declared by the overridden method
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        return this.transUserFactors.column(userIdx).dot(this.transItemFactors.column(itemIdx));
    }
}

