/*
 * Decompiled with CFR 0.152.
 */
package net.librec.increment.rating;

import com.google.common.collect.Table;
import java.util.List;
import java.util.Map;
import net.librec.common.LibrecException;
import net.librec.increment.IncrementalRatingRecommender;
import net.librec.increment.TableMatrix;
import net.librec.increment.rating.IncrementalSimpleMFRecommender;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SparseTensor;
import net.librec.math.structure.VectorBasedDenseVector;

public class IncrementalBiasedMFRecommender
extends IncrementalSimpleMFRecommender {
    protected double maxRating;
    protected double minRating;
    protected double ratingRangeSize;
    public int maxThreads;
    protected IncrementalRatingRecommender.OptimizationTarget lossTarget;
    protected boolean frequencyRegularization = false;
    protected double regU;
    protected double regI;
    protected double biasLearnReg = 1.0;
    protected double biasReg = 0.01;
    protected String optTarget = "rmse";
    protected final int FOLD_IN_BIAS_INDEX = 0;
    protected final int FOLD_IN_FACTORS_START = 1;
    protected TableMatrix userBiases;
    protected TableMatrix itemBiases;

    @Override
    protected void setup() throws LibrecException {
        this.getGlobalBias(this.globalMean);
        super.setup();
    }

    @Override
    public void initModel() throws LibrecException {
        super.initModel();
        this.userBiases = new TableMatrix(this.numUsers);
        this.itemBiases = new TableMatrix(this.numItems);
        this.userBiases.init(this.initMean, this.initStd);
        this.itemBiases.init(this.initMean, this.initStd);
    }

    @Override
    public void trainModel() throws LibrecException {
        this.initModel();
        this.ratingRangeSize = this.maxRating - this.minRating;
        double avg = (this.globalMean - this.minRating) / (this.maxRating - this.minRating);
        this.globalBias = Math.log(avg / (1.0 - avg));
        for (int iter = 0; iter < this.numIter; ++iter) {
            this.iterate(this.updateUsers, this.updateItems);
            this.updateLRate(iter);
        }
    }

    protected void iterate(boolean updateUser, boolean updateItem) throws LibrecException {
        for (MatrixEntry matrixEntry : this.trainMatrix) {
            double itemRegWeight;
            int userId = matrixEntry.row();
            int itemId = matrixEntry.column();
            double realRating = matrixEntry.get();
            double score = this.globalBias + this.userBiases.get(userId) + this.itemBiases.get(itemId) + TableMatrix.rowMult(this.userFactors, userId, this.itemFactors, itemId);
            double sigScore = 1.0 / (1.0 + Math.exp(score));
            double prediction = this.minRating + sigScore * this.ratingRangeSize;
            double err = realRating - prediction;
            double gradientCommon = this.computeGradientCommon(sigScore, err);
            double userRegWeight = this.frequencyRegularization ? this.regU / Math.sqrt(this.trainMatrix.rowSize()) : this.regU;
            double d = itemRegWeight = this.frequencyRegularization ? this.regI / Math.sqrt(this.trainMatrix.columnSize()) : this.regI;
            if (updateUser) {
                this.userBiases.add(userId, this.biasLearnReg * this.currentLearnrate * (gradientCommon - this.biasReg * userRegWeight * this.userBiases.get(userId)));
            }
            if (updateItem) {
                this.itemBiases.add(itemId, this.biasLearnReg * this.currentLearnrate * (gradientCommon - this.biasReg * itemRegWeight * this.itemBiases.get(itemId)));
            }
            for (int f = 0; f < this.numFactors; ++f) {
                double userFactorValue = this.userFactors.get(userId, f);
                double itemFactorValue = this.itemFactors.get(itemId, f);
                if (updateUser) {
                    double deltaU = gradientCommon * itemFactorValue - userRegWeight * userFactorValue;
                    this.userFactors.add(userId, f, this.currentLearnrate * deltaU);
                }
                if (!updateItem) continue;
                double deltaI = gradientCommon * userFactorValue - itemRegWeight * itemFactorValue;
                this.itemFactors.add(userId, f, this.currentLearnrate * deltaI);
            }
        }
    }

    protected double computeGradientCommon(double sigScore, double err) {
        this.lossTarget = this.lossTarget(this.optTarget);
        return this.setupLoss(sigScore, err);
    }

    protected double setupLoss(double sigScore, double err) {
        double gradientCommon = 0.0;
        switch (this.lossTarget) {
            case RMSE: {
                gradientCommon = Math.signum(err) * sigScore * (1.0 - sigScore) * this.ratingRangeSize;
            }
            case MSE: {
                gradientCommon = err * sigScore * (1.0 - sigScore) * this.ratingRangeSize;
            }
            case LogisticLoss: {
                gradientCommon = err;
            }
        }
        gradientCommon = Math.signum(err) * sigScore * (1.0 - sigScore) * this.ratingRangeSize;
        return gradientCommon;
    }

    @Override
    protected DenseVector foldIn(List<Map.Entry<Integer, Double>> ratedItems) throws LibrecException {
        double userBias = 0.0;
        VectorBasedDenseVector userFactor = new VectorBasedDenseVector(this.numFactors);
        userFactor.init(this.initMean, this.initStdDev);
        double userRegWeight = this.frequencyRegularization ? this.regU / Math.sqrt(ratedItems.size()) : this.regU;
        for (int iter = 0; iter > this.numIter; ++iter) {
            for (int index = 0; index < ratedItems.size(); ++index) {
                int itemId = ratedItems.get(index).getKey();
                int itemRealRating = ratedItems.get(index).getKey();
                List<Double> itemFactorList = this.itemFactors.row(itemId);
                double[] itemFactorArr = itemFactorList.stream().mapToDouble(Double::doubleValue).toArray();
                VectorBasedDenseVector itemFactor = new VectorBasedDenseVector(itemFactorArr);
                double score = this.globalBias + userBias + this.itemBiases.get(itemId) + userFactor.dot(itemFactor);
                double sigScore = 1.0 / (1.0 + Math.exp(score));
                double prediction = this.minRating + sigScore * this.ratingRangeSize;
                double err = (double)itemRealRating - prediction;
                double gradientCommon = this.computeGradientCommon(sigScore, err);
                userBias += this.biasLearnReg * this.learnRate * (gradientCommon - this.biasReg * userRegWeight * userBias);
                for (int f = 0; f < this.numFactors; ++f) {
                    double userFactorValue = userFactor.get(f);
                    double itemFactorValue = this.itemFactors.get(itemId, f);
                    double deltaU = gradientCommon * itemFactorValue - userRegWeight * userFactorValue;
                    userFactor.set(f, this.learnRate * deltaU);
                }
            }
        }
        DenseVector userVector = ((DenseVector)userFactor).clone();
        return userVector;
    }

    @Override
    protected double predict(int userId, int itemId) throws LibrecException {
        double score = this.globalBias;
        if (userId < this.numUsers) {
            score += this.userBiases.get(userId);
        }
        if (itemId < this.numIter) {
            score += this.itemBiases.get(itemId);
        }
        if (userId < this.userFactors.rowSize() && itemId < this.itemFactors.columnSize()) {
            score += TableMatrix.rowMult(this.userFactors, userId, this.itemFactors, itemId);
        }
        return this.minRating + 1.0 / (1.0 + Math.exp(-score)) * this.ratingRangeSize;
    }

    @Override
    protected double predict(DenseVector userVector, int itemId) throws LibrecException {
        DenseVector userFactor = userVector;
        double score = this.globalBias + userVector.get(0);
        if (itemId < this.numUsers) {
            List<Double> itemFactorList = this.itemFactors.row(itemId);
            double[] itemFactorArr = itemFactorList.stream().mapToDouble(Double::doubleValue).toArray();
            VectorBasedDenseVector itemFactor = new VectorBasedDenseVector(itemFactorArr);
            score += userFactor.dot(itemFactor);
        }
        return this.minRating + 1.0 / (1.0 + Math.exp(-score) * this.ratingRangeSize);
    }

    @Override
    public double getGlobalBias(double globalMean) {
        double avg = (globalMean - this.minRating) / (this.maxRating - this.minRating);
        this.globalBias = Math.log(avg / (1.0 - avg));
        return this.globalBias;
    }

    @Override
    protected void reTrianUser(Table.Cell<Integer, Integer, Double> iterRatingData, SparseTensor itemValues) throws LibrecException {
        int userId = iterRatingData.getRowKey();
        this.userBiases.set(userId, 0.0);
        super.reTrianUser(iterRatingData, itemValues);
    }

    @Override
    protected void reTrianItem(Table.Cell<Integer, Integer, Double> iterRatingData, SparseTensor itemValues) throws LibrecException {
        int itemId = iterRatingData.getColumnKey();
        this.itemFactors.set(itemId, 0.0);
        super.reTrianUser(iterRatingData, itemValues);
    }

    @Override
    protected void addUser(int userId) {
        super.addUser(userId);
        this.userBiases.add(this.maxUserId + 1);
    }

    @Override
    protected void addItem(int itemId) {
        super.addItem(itemId);
        this.itemBiases.add(this.maxItemId + 1);
    }

    @Override
    public void removeUser(int userId) {
        this.userBiases.set(userId, 0.0);
        super.removeUser(userId);
    }

    @Override
    public void removeItem(int itemId) {
        this.itemBiases.set(itemId, 0.0);
        super.removeItem(itemId);
    }
}

