/*
 * Decompiled with CFR 0.152.
 */
package net.librec.increment.rating;

import com.google.common.collect.Table;
import java.util.Iterator;
import net.librec.common.LibrecException;
import net.librec.increment.IncrementalRatingRecommender;
import net.librec.increment.TableMatrix;
import net.librec.math.structure.MatrixEntry;

/**
 * Incremental rating predictor that scores a (user, item) pair as
 * {@code globalAverage + userBias[user] + itemBias[item]}, clamped to
 * {@code [minRating, maxRating]}.
 *
 * <p>Training alternates regularized residual averages: item biases are recomputed against the
 * current user biases, then user biases against the fresh item biases, {@link #numIter} times.
 * When ratings are added, updated or removed, the biases of every user and item touched by the
 * change are recomputed individually via {@link #retrainUser(int)} and {@link #retrainItem(int)}.
 */
public class UserItemBaseline
extends IncrementalRatingRecommender {
    /** Regularization added to a user's rating count when averaging that user's residuals. */
    public double regU = 15.0;
    /** Regularization added to an item's rating count when averaging that item's residuals. */
    public double regI = 10.0;
    /** Number of alternating optimization rounds run by {@link #trainModel()}. */
    public int numIter = 10;
    /** Global rating mean, copied from {@code globalMean} when training starts. */
    public double globalAverage;
    /** Per-user bias, indexed by user id. */
    protected TableMatrix userBiases;
    /** Per-item bias, indexed by item id. */
    protected TableMatrix itemBiases;

    /**
     * Allocates the bias vectors, captures the global mean and runs {@link #numIter} rounds of
     * {@link #iterate()}.
     *
     * @throws LibrecException declared by the recommender contract
     */
    @Override
    public void trainModel() throws LibrecException {
        this.userBiases = new TableMatrix(this.numUsers);
        this.itemBiases = new TableMatrix(this.numItems);
        this.globalAverage = this.globalMean;
        for (int iter = 0; iter < this.numIter; ++iter) {
            this.iterate();
        }
    }

    /** Runs one alternating round: item biases first, then user biases. */
    public void iterate() {
        this.optimizeItemBiases();
        this.optimizeUserBiases();
    }

    /**
     * Recomputes every user bias as
     * {@code sum(r_ui - globalAverage - itemBias[i]) / (regU + n_u)} over the user's training
     * ratings; users without ratings keep a bias of 0.
     */
    protected void optimizeUserBiases() {
        TableMatrix userRatingCount = new TableMatrix(this.numUsers);
        this.userBiases.init(0.0);
        for (MatrixEntry matrixEntry : this.trainMatrix) {
            int userId = matrixEntry.row();
            int itemId = matrixEntry.column();
            double residual = matrixEntry.get() - this.globalAverage - this.itemBiases.get(itemId);
            // Residuals and counts are keyed by the USER, since this pass estimates user biases.
            this.userBiases.add(userId, residual);
            userRatingCount.add(userId, 1.0);
        }
        divideByRegularizedCount(this.userBiases, userRatingCount, this.regU);
    }

    /**
     * Recomputes every item bias as
     * {@code sum(r_ui - globalAverage - userBias[u]) / (regI + n_i)} over the item's training
     * ratings; items without ratings keep a bias of 0.
     */
    protected void optimizeItemBiases() {
        // Counts are per item, so the count vector is sized like the item dimension.
        TableMatrix itemRatingCount = new TableMatrix(this.numItems);
        this.itemBiases.init(0.0);
        for (MatrixEntry matrixEntry : this.trainMatrix) {
            int userId = matrixEntry.row();
            int itemId = matrixEntry.column();
            double residual = matrixEntry.get() - this.globalAverage - this.userBiases.get(userId);
            this.itemBiases.add(itemId, residual);
            itemRatingCount.add(itemId, 1.0);
        }
        divideByRegularizedCount(this.itemBiases, itemRatingCount, this.regI);
    }

    /**
     * Divides each accumulated residual sum in {@code sums} by {@code reg + count}, for every
     * index whose entry in {@code counts} is non-zero.
     *
     * @param sums   residual sums, updated in place
     * @param counts per-index rating counts (column key = index, value = count)
     * @param reg    regularization constant added to each count
     */
    private static void divideByRegularizedCount(TableMatrix sums, TableMatrix counts, double reg) {
        Iterator<Table.Cell<Integer, Integer, Double>> iterator = counts.iterator();
        while (iterator.hasNext()) {
            // Fetch the cell once: key and value must come from the same entry.
            Table.Cell<Integer, Integer, Double> cell = iterator.next();
            double count = cell.getValue();
            if (count == 0.0) {
                continue;
            }
            int index = cell.getColumnKey();
            sums.set(index, sums.get(index) / (reg + count));
        }
    }

    /**
     * Predicts a rating as {@code globalAverage + userBias + itemBias}, where an out-of-range id
     * contributes a bias of 0, clamped to {@code [minRating, maxRating]}.
     *
     * @param userId user index
     * @param itemId item index
     * @return the clamped baseline prediction
     */
    @Override
    public double predict(int userId, int itemId) {
        double userBias = userId >= 0 && userId < this.userBiases.columnSize()
                ? this.userBiases.get(userId) : 0.0;
        double itemBias = itemId >= 0 && itemId < this.itemBiases.columnSize()
                ? this.itemBiases.get(itemId) : 0.0;
        double result = this.globalAverage + userBias + itemBias;
        if (result > this.maxRating) {
            return this.maxRating;
        }
        if (result < this.minRating) {
            return this.minRating;
        }
        return result;
    }

    /**
     * Recomputes one user's bias from scratch as
     * {@code sum(r_ui - globalAverage - itemBias[i]) / (regU + n_u)} over that user's current
     * training ratings; a user with no ratings gets a bias of 0.
     *
     * @param userId the user to retrain
     */
    public void retrainUser(int userId) {
        // NOTE(review): this user-side retrain is gated on updateItems; an updateUsers flag may be
        // the intended guard — confirm against IncrementalRatingRecommender.
        if (!this.updateItems) {
            return;
        }
        double residualSum = 0.0;
        int itemCount = 0;
        for (MatrixEntry matrixEntry : this.trainMatrix) {
            if (matrixEntry.row() != userId) {
                continue;
            }
            int itemId = matrixEntry.column();
            residualSum += matrixEntry.get() - this.globalAverage - this.itemBiases.get(itemId);
            ++itemCount;
        }
        // Overwrite rather than accumulate so a retrain never builds on the stale bias.
        this.userBiases.set(userId, itemCount == 0 ? 0.0 : residualSum / (this.regU + itemCount));
    }

    /**
     * Recomputes one item's bias from scratch as
     * {@code sum(r_ui - globalAverage) / (regI + n_i)} over that item's current training ratings;
     * an item with no ratings gets a bias of 0.
     *
     * @param itemId the item to retrain
     */
    public void retrainItem(int itemId) {
        if (!this.updateItems) {
            return;
        }
        double residualSum = 0.0;
        int userCount = 0;
        for (MatrixEntry matrixEntry : this.trainMatrix) {
            if (matrixEntry.column() != itemId) {
                continue;
            }
            // NOTE(review): unlike optimizeItemBiases, user biases are not subtracted here;
            // behavior kept as-is — confirm this is intended.
            residualSum += matrixEntry.get() - this.globalAverage;
            ++userCount;
        }
        // Overwrite rather than accumulate so a retrain never builds on the stale bias.
        this.itemBiases.set(itemId, userCount == 0 ? 0.0 : residualSum / (this.regI + userCount));
    }

    /**
     * Adds ratings via the base class, then retrains every user and item they touch.
     *
     * @param newRatings the ratings to add
     * @throws LibrecException propagated from the base class
     */
    @Override
    public void addRatings(TableMatrix newRatings) throws LibrecException {
        super.addRatings(newRatings);
        this.retrainUsersAndItems(newRatings);
    }

    /**
     * Updates ratings via the base class, then retrains every user and item they touch.
     *
     * @param newRatings the ratings to update
     * @throws LibrecException propagated from the base class
     */
    @Override
    public void updateRatings(TableMatrix newRatings) throws LibrecException {
        super.updateRatings(newRatings);
        this.retrainUsersAndItems(newRatings);
    }

    /**
     * Removes ratings via the base class, then retrains every user and item they touched.
     *
     * @param newRatings the ratings to remove
     * @throws LibrecException propagated from the base class
     */
    @Override
    public void removeRatings(TableMatrix newRatings) throws LibrecException {
        super.removeRatings(newRatings);
        this.retrainUsersAndItems(newRatings);
    }

    /** Registers a new user; delegates entirely to the base class. */
    @Override
    public void addUser(int userId) {
        super.addUser(userId);
    }

    /** Registers a new item; delegates to the base class {@code addItem}. */
    public void addItems(int itemId) {
        super.addItem(itemId);
    }

    /**
     * For each cell of {@code newRatings}, retrains the cell's user (row key) and then its item
     * (column key).
     *
     * @param newRatings the changed ratings
     * @throws LibrecException declared for subclasses; not thrown here
     */
    public void retrainUsersAndItems(TableMatrix newRatings) throws LibrecException {
        Iterator<Table.Cell<Integer, Integer, Double>> it = newRatings.iterator();
        while (it.hasNext()) {
            Table.Cell<Integer, Integer, Double> iterRatingData = it.next();
            int userId = iterRatingData.getRowKey();
            int itemId = iterRatingData.getColumnKey();
            this.retrainUser(userId);
            this.retrainItem(itemId);
        }
    }
}

