/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.baseline;

import com.google.common.collect.Sets;
import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.TreeSet;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.Vector;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.MatrixProbabilisticGraphicalRecommender;

/**
 * User-cluster baseline recommender fitted with EM.
 *
 * <p>Every user softly belongs to one of {@code numTopics} latent clusters. Cluster {@code k} has
 * a prior weight ({@code topicInitialProbs}) and a distribution over the observed rating levels
 * ({@code topicRatingProbs}). {@link #eStep()} computes each user's cluster posterior,
 * {@link #mStep()} re-estimates the priors and rating distributions from those posteriors, and
 * {@link #predict(int, int)} returns the posterior-weighted expected rating.
 */
public class UserClusterRecommender
extends MatrixProbabilisticGraphicalRecommender {
    /**
     * Precision for the BigDecimal products in the E-step. Exact products of
     * {@code new BigDecimal(double)} values grow by tens of significant digits per factor, which
     * makes users with many ratings very slow; a bounded context keeps each operation cheap while
     * staying far more precise than the final rounding below.
     */
    private static final MathContext POSTERIOR_PRECISION = MathContext.DECIMAL128;
    /** Decimal places kept when normalising a user's cluster posterior. */
    private static final int POSTERIOR_SCALE = 6;

    /** P(rating level r | cluster k); numTopics x numRatingLevels. */
    private DenseMatrix topicRatingProbs;
    /** Prior weight of cluster k; length numTopics. */
    private DenseVector topicInitialProbs;
    /** Posterior P(cluster k | user u); numUsers x numTopics. */
    private DenseMatrix userTopicProbs;
    /** Number of training ratings of user u at rating level r; numUsers x numRatingLevels. */
    private DenseMatrix userNumEachRating;
    /** Number of training ratings of user u (row sum of userNumEachRating). */
    private DenseVector userNumRatings;
    private int numTopics;
    private int numRatingLevels;
    /** Loss (negated expected log-likelihood) seen at the previous convergence check. */
    private double lastLoss;

    /**
     * Builds the rating scale, randomly initialises the cluster parameters and caches per-user
     * rating-level counts.
     *
     * @throws LibrecException if the superclass setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        TreeSet<Double> ratingScaleSet = Sets.newTreeSet(this.trainMatrix.getDataTable().values());
        ratingScale = new ArrayList<Double>(ratingScaleSet);
        this.numRatingLevels = ratingScale.size();
        this.numTopics = this.conf.getInt("rec.factory.number", 10);
        this.topicRatingProbs = new DenseMatrix(this.numTopics, this.numRatingLevels);
        for (int k = 0; k < this.numTopics; ++k) {
            double[] probs = Randoms.randProbs(this.numRatingLevels);
            for (int r = 0; r < this.numRatingLevels; ++r) {
                this.topicRatingProbs.set(k, r, probs[r]);
            }
        }
        this.topicInitialProbs = new VectorBasedDenseVector(Randoms.randProbs(this.numTopics));
        this.userTopicProbs = new DenseMatrix(this.numUsers, this.numTopics);
        this.userNumEachRating = new DenseMatrix(this.numUsers, this.numRatingLevels);
        this.userNumRatings = new VectorBasedDenseVector(this.numUsers);
        for (int u = 0; u < this.numUsers; ++u) {
            SequentialSparseVector ru = this.trainMatrix.row(u);
            int count = 0;
            for (Vector.VectorEntry ve : ru) {
                int r = ratingScale.indexOf(ve.get());
                this.userNumEachRating.plus(u, r, 1.0);
                ++count;
            }
            // Counted from the same iteration so it is exactly the row sum of userNumEachRating,
            // which the M-step uses as the normaliser of each cluster's rating distribution.
            this.userNumRatings.set(u, count);
        }
        // The loss is minimised; start from the largest value. (The first check never stops.)
        this.lastLoss = Double.MAX_VALUE;
    }

    /**
     * E-step: computes P(k | u), proportional to pi_k * prod_r p_kr^{n_ur}, for every user and
     * stores it, rounded to {@value #POSTERIOR_SCALE} decimals, in {@code userTopicProbs}.
     */
    @Override
    protected void eStep() {
        BigDecimal[] joint = new BigDecimal[this.numTopics];
        for (int u = 0; u < this.numUsers; ++u) {
            BigDecimal evidence = BigDecimal.ZERO;
            for (int k = 0; k < this.numTopics; ++k) {
                BigDecimal p = new BigDecimal(this.topicInitialProbs.get(k));
                for (int r = 0; r < this.numRatingLevels; ++r) {
                    // Counts are whole numbers stored as doubles.
                    int n = (int) this.userNumEachRating.get(u, r);
                    if (n > 0) {
                        BigDecimal pkr = new BigDecimal(this.topicRatingProbs.get(k, r));
                        p = p.multiply(pkr.pow(n, POSTERIOR_PRECISION), POSTERIOR_PRECISION);
                    }
                }
                joint[k] = p;
                evidence = evidence.add(p, POSTERIOR_PRECISION);
            }
            if (evidence.signum() == 0) {
                // No cluster can explain this user's ratings; dividing would throw
                // ArithmeticException, so fall back to the cluster priors.
                for (int k = 0; k < this.numTopics; ++k) {
                    this.userTopicProbs.set(u, k, this.topicInitialProbs.get(k));
                }
                continue;
            }
            for (int k = 0; k < this.numTopics; ++k) {
                double zuk = joint[k].divide(evidence, POSTERIOR_SCALE, RoundingMode.HALF_UP).doubleValue();
                this.userTopicProbs.set(u, k, zuk);
            }
        }
    }

    /**
     * M-step: re-estimates each cluster's rating distribution and the cluster priors from the
     * current posteriors.
     */
    @Override
    protected void mStep() {
        double[] topicWeights = new double[this.numTopics];
        double totalWeight = 0.0;
        for (int k = 0; k < this.numTopics; ++k) {
            double weight = 0.0;      // sum_u P(k|u)
            double ratingMass = 0.0;  // sum_u P(k|u) * n_u; does not depend on r
            for (int u = 0; u < this.numUsers; ++u) {
                double ruk = this.userTopicProbs.get(u, k);
                weight += ruk;
                ratingMass += ruk * this.userNumRatings.get(u);
            }
            // A cluster explaining no ratings would give 0/0 = NaN for every level; keep its
            // previous distribution rather than poisoning later E-steps with NaN.
            if (ratingMass > 0.0) {
                for (int r = 0; r < this.numRatingLevels; ++r) {
                    double numerator = 0.0;
                    for (int u = 0; u < this.numUsers; ++u) {
                        numerator += this.userTopicProbs.get(u, k) * this.userNumEachRating.get(u, r);
                    }
                    this.topicRatingProbs.set(k, r, numerator / ratingMass);
                }
            }
            topicWeights[k] = weight;
            totalWeight += weight;
        }
        if (totalWeight > 0.0) {
            for (int k = 0; k < this.numTopics; ++k) {
                this.topicInitialProbs.set(k, topicWeights[k] / totalWeight);
            }
        }
    }

    /**
     * Evaluates the loss (negated expected complete-data log-likelihood) and reports convergence
     * once it stops decreasing or becomes NaN.
     *
     * @param iter current iteration number; the first check only records the loss
     * @return {@code true} if training should stop
     */
    @Override
    protected boolean isConverged(int iter) {
        double logLikelihood = 0.0;
        for (int u = 0; u < this.numUsers; ++u) {
            for (int k = 0; k < this.numTopics; ++k) {
                double ruk = this.userTopicProbs.get(u, k);
                if (ruk == 0.0) {
                    continue; // 0 * log(...) contributes nothing; avoids 0 * -Inf = NaN
                }
                double sum_nl = 0.0;
                for (int r = 0; r < this.numRatingLevels; ++r) {
                    double nur = this.userNumEachRating.get(u, r);
                    if (nur != 0.0) {
                        sum_nl += nur * Math.log(this.topicRatingProbs.get(k, r));
                    }
                }
                logLikelihood += ruk * (Math.log(this.topicInitialProbs.get(k)) + sum_nl);
            }
        }
        // EM aims to increase the log-likelihood; negate it so "converged" means the loss did
        // not go down (including no change at float resolution).
        double loss = -logLikelihood;
        float deltaLoss = (float) (loss - this.lastLoss);
        if (iter > 1 && (deltaLoss >= 0.0f || Float.isNaN(deltaLoss))) {
            return true;
        }
        this.lastLoss = loss;
        return false;
    }

    /**
     * Predicts the posterior-weighted expected rating of the user; the item does not affect the
     * result.
     *
     * @param userIdx inner user index
     * @param itemIdx inner item index (unused)
     * @return sum_k P(k|u) * sum_r rating_r * p_kr
     * @throws LibrecException never thrown here; declared by the superclass contract
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        double pred = 0.0;
        for (int k = 0; k < this.numTopics; ++k) {
            double expectedRating = 0.0;
            for (int r = 0; r < this.numRatingLevels; ++r) {
                double level = (Double) ratingScale.get(r);
                expectedRating += level * this.topicRatingProbs.get(k, r);
            }
            pred += this.userTopicProbs.get(userIdx, k) * expectedRating;
        }
        return pred;
    }
}

