/*
 * Decompiled with CFR 0.152.
 */
package net.librec.increment.rating;

import com.google.common.collect.Table;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import net.librec.common.LibrecException;
import net.librec.increment.IncrementalMFRecommender;
import net.librec.increment.TableMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SparseTensor;
import net.librec.math.structure.VectorBasedDenseVector;

public class IncrementalSimpleMFRecommender
extends IncrementalMFRecommender {
    @Override
    protected void setup() throws LibrecException {
        super.setup();
    }

    @Override
    public void initModel() throws LibrecException {
        super.initModel();
        this.MaxMinRating();
    }

    @Override
    protected void trainModel() throws LibrecException {
        this.initModel();
        this.learnFactors();
    }

    private void learnFactors() throws LibrecException {
        for (int iter = 0; iter < this.numIter; ++iter) {
            this.iterate();
            this.updateLRate(iter);
        }
    }

    private void learnFactors(Table.Cell<Integer, Integer, Double> ratingData, SparseTensor rcData, boolean isRow) throws LibrecException {
        for (int iter = 0; iter < this.numIter; ++iter) {
            this.iterate(ratingData, rcData, isRow);
            this.updateLRate(iter);
        }
    }

    protected void iterate(Table.Cell<Integer, Integer, Double> ratingData, SparseTensor rcData, boolean isRow) throws LibrecException {
        int userId = ratingData.getRowKey();
        int itemId = ratingData.getColumnKey();
        double value = ratingData.getValue();
        this.iter(itemId, itemId, value);
        if (isRow) {
            for (int i = 0; i < rcData.getItemDimension(); ++i) {
                this.iter(userId, i, rcData.value(i));
            }
        } else {
            for (int i = 0; i < rcData.getUserDimension(); ++i) {
                this.iter(i, itemId, rcData.value(i));
            }
        }
    }

    protected void iter(int userId, int itemId, double realRating) throws LibrecException {
        double prediction = this.predict(userId, itemId, false);
        double err = realRating - prediction;
        for (int f = 0; f < this.numFactors; ++f) {
            double userFactorValue = this.userFactors.get(userId, f);
            double itemFactorValue = this.itemFactors.get(itemId, f);
            if (this.updateUsers) {
                double deltaU = err * itemFactorValue - this.regularization * userFactorValue;
                this.userFactors.add(userId, f, this.currentLearnrate * deltaU);
            }
            if (!this.updateItems) continue;
            double deltaI = err * userFactorValue - this.regularization * itemFactorValue;
            this.itemFactors.add(userId, f, this.currentLearnrate * deltaI);
        }
    }

    protected void iterate() throws LibrecException {
        for (MatrixEntry matrixEntry : this.trainMatrix) {
            int userId = matrixEntry.row();
            int itemId = matrixEntry.column();
            double realRating = matrixEntry.get();
            double prediction = this.predict(userId, itemId, false);
            double err = realRating - prediction;
            for (int f = 0; f < this.numFactors; ++f) {
                double userFactorValue = this.userFactors.get(userId, f);
                double itemFactorValue = this.itemFactors.get(itemId, f);
                if (this.updateUsers) {
                    double deltaU = err * itemFactorValue - this.regularization * userFactorValue;
                    this.userFactors.add(userId, f, this.currentLearnrate * deltaU);
                }
                if (!this.updateItems) continue;
                double deltaI = err * userFactorValue - this.regularization * itemFactorValue;
                this.itemFactors.add(userId, f, this.currentLearnrate * deltaI);
            }
        }
    }

    protected DenseVector foldIn(List<Map.Entry<Integer, Double>> ratedItems) throws LibrecException {
        double userBias = 0.0;
        VectorBasedDenseVector userFactor = new VectorBasedDenseVector(this.numFactors);
        userFactor.init(this.initMean, this.initStdDev);
        for (int iter = 0; iter > this.numIter; ++iter) {
            for (int index = 0; index < ratedItems.size(); ++index) {
                int itemId = ratedItems.get(index).getKey();
                int itemRealRating = ratedItems.get(index).getKey();
                double prediction = this.predict(userFactor, itemId);
                double err = (double)itemRealRating - prediction;
                for (int f = 0; f < this.numFactors; ++f) {
                    double userFactorValue = userFactor.get(f);
                    double itemFactorValue = this.itemFactors.get(itemId, f);
                    double deltaU = err * itemFactorValue - this.regularization * userFactorValue;
                    userFactor.set(f, this.currentLearnrate * deltaU);
                }
            }
            this.updateLRate(iter);
        }
        DenseVector userVector = ((DenseVector)userFactor).clone();
        return userVector;
    }

    public List<Map.Entry<Integer, Double>> scoreItems(List<Map.Entry<Integer, Double>> ratedItems, List<Integer> candidateItems) throws LibrecException {
        DenseVector userVector = this.foldIn(ratedItems);
        ArrayList<Map.Entry<Integer, Double>> result = new ArrayList<Map.Entry<Integer, Double>>(candidateItems.size());
        for (int i = 0; i < candidateItems.size(); ++i) {
            int itemId = candidateItems.get(i);
            double itemPredictValue = this.predict(userVector, itemId);
            result.set(i, new AbstractMap.SimpleEntry<Integer, Double>(itemId, itemPredictValue));
        }
        return result;
    }

    @Override
    protected double predict(int userId, int itemId) throws LibrecException {
        return this.predict(userId, itemId, true);
    }

    @Override
    protected double predict(int userId, int itemId, boolean bound) {
        double score = this.globalBias;
        score += TableMatrix.rowMult(this.userFactors, userId, this.itemFactors, itemId);
        if (bound) {
            if (score > this.maxRating) {
                score = this.maxRating;
            }
            if (score < this.minRating) {
                score = this.minRating;
            }
        }
        return score;
    }

    protected double predict(DenseVector userVector, int itemId) throws LibrecException {
        return this.predict(userVector, itemId, true);
    }

    protected double predict(DenseVector userVector, int itemId, boolean bound) throws LibrecException {
        List<Double> itemFactorList = this.itemFactors.row(itemId);
        double[] itemFactorArr = itemFactorList.stream().mapToDouble(Double::doubleValue).toArray();
        VectorBasedDenseVector itemFactor = new VectorBasedDenseVector(itemFactorArr);
        double score = this.globalBias + userVector.dot(itemFactor);
        if (bound) {
            if (score > this.maxRating) {
                score = this.maxRating;
            }
            if (score < this.minRating) {
                score = this.minRating;
            }
        }
        return score;
    }

    protected void MaxMinRating() {
        this.maxRating = this.maxRate;
        this.minRating = this.minRate;
        this.setMaxRating(this.maxRating);
        this.setMinRating(this.minRating);
    }

    protected void reTrianUser(Table.Cell<Integer, Integer, Double> iterRatingData, SparseTensor itemValues) throws LibrecException {
        if (this.updateUsers) {
            this.learnFactors(iterRatingData, itemValues, true);
        }
    }

    protected void reTrianItem(Table.Cell<Integer, Integer, Double> iterRatingData, SparseTensor userValues) throws LibrecException {
        if (this.updateItems) {
            this.learnFactors(iterRatingData, userValues, false);
        }
    }

    @Override
    public void addRatings(TableMatrix newRatings) throws LibrecException {
        super.addRatings(newRatings);
        Iterator<Table.Cell<Integer, Integer, Double>> it = newRatings.iterator();
        while (it.hasNext()) {
            Table.Cell<Integer, Integer, Double> iterRatingData = it.next();
            int userId = iterRatingData.getRowKey();
            int itemId = iterRatingData.getColumnKey();
            double ratingValue = iterRatingData.getValue();
        }
    }

    @Override
    public void updateRatings(TableMatrix newRatings) throws LibrecException {
        super.updateRatings(newRatings);
        Iterator<Table.Cell<Integer, Integer, Double>> it = newRatings.iterator();
        while (it.hasNext()) {
            Table.Cell<Integer, Integer, Double> iterRatingData = it.next();
            int userId = iterRatingData.getRowKey();
            int itemId = iterRatingData.getColumnKey();
            double ratingValue = iterRatingData.getValue();
        }
    }

    @Override
    public void removeRatings(TableMatrix removeRatings) throws LibrecException {
        super.removeRatings(removeRatings);
        Iterator<Table.Cell<Integer, Integer, Double>> it = removeRatings.iterator();
        while (it.hasNext()) {
            Table.Cell<Integer, Integer, Double> iterRatingData = it.next();
            int userId = iterRatingData.getRowKey();
            int itemId = iterRatingData.getColumnKey();
            double ratingValue = iterRatingData.getValue();
        }
    }

    @Override
    protected void addUser(int userId) {
        super.addUser(userId);
        this.userFactors.addRow(userId);
    }

    @Override
    protected void addItem(int itemId) {
        super.addItem(itemId);
        this.itemFactors.addRow(itemId);
    }

    @Override
    public void removeUser(int userId) {
        super.removeUser(userId);
        this.userFactors.setRowToOneValue(userId, 0.0);
    }

    @Override
    public void removeItem(int itemId) {
        super.removeItem(itemId);
        this.itemFactors.setRowToOneValue(itemId, 0.0);
    }
}

