/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.baseline;

import com.google.common.collect.Sets;
import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.Collections;
import java.util.TreeSet;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.Vector;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.MatrixProbabilisticGraphicalRecommender;

/**
 * Item-cluster baseline recommender.
 *
 * <p>Items are softly assigned to {@code numTopics} latent clusters ("topics"). Each topic k
 * carries a categorical distribution {@code topicRatingProbs[k][r]} over the observed rating
 * levels, and topics are mixed with weights {@code topicInitialProbs[k]}. Parameters are fitted
 * with EM: {@link #eStep()} computes each item's topic responsibilities from its rating
 * histogram, and {@link #mStep()} re-estimates the topic parameters from those responsibilities.
 *
 * <p>A prediction depends only on the item: it is the expected rating under the item's topic
 * mixture, so {@code userIdx} is ignored.
 */
public class ItemClusterRecommender
extends MatrixProbabilisticGraphicalRecommender {
    /** Significant digits kept in the BigDecimal likelihood products of the E-step. */
    private static final MathContext LIKELIHOOD_PRECISION = MathContext.DECIMAL128;

    /** Decimal places kept for each responsibility value. */
    private static final int RESPONSIBILITY_SCALE = 6;

    /** P(rating level r | topic k); numTopics x numRatingLevels. */
    private DenseMatrix topicRatingProbs;
    /** Mixing weight of each topic. */
    private DenseVector topicInitialProbs;
    /** Responsibility of topic k for item i; numItems x numTopics. */
    private DenseMatrix itemTopicProbs;
    /** Number of training ratings of item i at rating level r; numItems x numRatingLevels. */
    private DenseMatrix itemNumEachRating;
    /** Total number of training ratings of item i. */
    private DenseVector itemNumRatings;
    private int numTopics;
    private int numRatingLevels;
    /** Loss (negated expected log-likelihood) recorded by the previous convergence check. */
    private double lastLoss;

    /**
     * Builds the rating scale, randomly initializes the topic parameters and counts each item's
     * rating histogram.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.isRanking = false;
        // A TreeSet iterates in ascending order, so the resulting list is already sorted.
        TreeSet<Double> ratingScaleSet = Sets.newTreeSet(this.trainMatrix.getDataTable().values());
        ratingScale = new ArrayList<Double>(ratingScaleSet);
        this.numRatingLevels = ratingScale.size();
        this.numTopics = this.conf.getInt("rec.pgm.number", 10);

        // Random initial parameters; each row of topicRatingProbs is a probability vector.
        this.topicRatingProbs = new DenseMatrix(this.numTopics, this.numRatingLevels);
        for (int k = 0; k < this.numTopics; ++k) {
            double[] probs = Randoms.randProbs(this.numRatingLevels);
            for (int r = 0; r < this.numRatingLevels; ++r) {
                this.topicRatingProbs.set(k, r, probs[r]);
            }
        }
        this.topicInitialProbs = new VectorBasedDenseVector(Randoms.randProbs(this.numTopics));
        this.itemTopicProbs = new DenseMatrix(this.numItems, this.numTopics);

        // Rating histogram of every item (a column of the training matrix).
        this.itemNumEachRating = new DenseMatrix(this.numItems, this.numRatingLevels);
        this.itemNumRatings = new VectorBasedDenseVector(this.numItems);
        for (int i = 0; i < this.numItems; ++i) {
            SequentialSparseVector ri = this.trainMatrix.column(i);
            int count = 0;
            for (Vector.VectorEntry vi : ri) {
                double rui = vi.get();
                int r = ratingScale.indexOf(rui);
                this.itemNumEachRating.plus(i, r, 1.0);
                ++count;
            }
            // Counted from the same entries as the histogram, so the row total always matches it.
            this.itemNumRatings.set(i, count);
        }
        // Placeholder only: isConverged compares losses only when iter > 1.
        this.lastLoss = Double.MAX_VALUE;
    }

    /**
     * E-step: recomputes the topic responsibilities of every item.
     *
     * <p>The unnormalized weight of topic k for item i is {@code pi_k * prod_r p_kr^(n_ir)}, where
     * {@code n_ir} is the item's count at rating level r. Such products can underflow a double
     * for items with many ratings, so they are evaluated as BigDecimal with bounded precision and
     * then normalized to {@link #RESPONSIBILITY_SCALE} decimal places.
     */
    @Override
    protected void eStep() {
        BigDecimal[] weights = new BigDecimal[this.numTopics];
        for (int i = 0; i < this.numItems; ++i) {
            BigDecimal total = BigDecimal.ZERO;
            for (int k = 0; k < this.numTopics; ++k) {
                BigDecimal weight = new BigDecimal(this.topicInitialProbs.get(k));
                for (int r = 0; r < this.numRatingLevels; ++r) {
                    int count = (int) this.itemNumEachRating.get(i, r);
                    if (count > 0) {
                        BigDecimal pkr = new BigDecimal(this.topicRatingProbs.get(k, r));
                        // Bounded precision: exact BigDecimal products grow by ~50 digits per factor.
                        weight = weight.multiply(pkr.pow(count, LIKELIHOOD_PRECISION), LIKELIHOOD_PRECISION);
                    }
                }
                weights[k] = weight;
                // Bounded as well: an exact sum of values with very different exponents is huge.
                total = total.add(weight, LIKELIHOOD_PRECISION);
            }
            if (total.signum() == 0) {
                // No topic can generate this item's ratings; use the prior instead of dividing by 0.
                for (int k = 0; k < this.numTopics; ++k) {
                    this.itemTopicProbs.set(i, k, this.topicInitialProbs.get(k));
                }
                continue;
            }
            for (int k = 0; k < this.numTopics; ++k) {
                double zik = weights[k].divide(total, RESPONSIBILITY_SCALE, RoundingMode.HALF_UP).doubleValue();
                this.itemTopicProbs.set(i, k, zik);
            }
        }
    }

    /**
     * M-step: re-estimates each topic's rating distribution and the topic mixing weights from
     * the current responsibilities. A topic with zero responsibility mass keeps its previous
     * rating distribution (the update would be 0/0).
     */
    @Override
    protected void mStep() {
        double[] topicMass = new double[this.numTopics];
        double totalMass = 0.0;
        for (int k = 0; k < this.numTopics; ++k) {
            // Expected number of ratings assigned to topic k; independent of the rating level r.
            double denominator = 0.0;
            double mass = 0.0;
            for (int i = 0; i < this.numItems; ++i) {
                double zik = this.itemTopicProbs.get(i, k);
                denominator += zik * this.itemNumRatings.get(i);
                mass += zik;
            }
            if (denominator > 0.0) {
                for (int r = 0; r < this.numRatingLevels; ++r) {
                    double numerator = 0.0;
                    for (int i = 0; i < this.numItems; ++i) {
                        numerator += this.itemTopicProbs.get(i, k) * this.itemNumEachRating.get(i, r);
                    }
                    this.topicRatingProbs.set(k, r, numerator / denominator);
                }
            }
            topicMass[k] = mass;
            totalMass += mass;
        }
        if (totalMass > 0.0) {
            for (int k = 0; k < this.numTopics; ++k) {
                this.topicInitialProbs.set(k, topicMass[k] / totalMass);
            }
        }
    }

    /**
     * Evaluates the negated expected complete-data log-likelihood and reports convergence once it
     * stops decreasing (or becomes NaN).
     *
     * @param iter current iteration; losses are compared only when {@code iter > 1}
     * @return {@code true} if training should stop
     */
    @Override
    protected boolean isConverged(int iter) {
        double logLikelihood = 0.0;
        for (int i = 0; i < this.numItems; ++i) {
            for (int k = 0; k < this.numTopics; ++k) {
                double rik = this.itemTopicProbs.get(i, k);
                if (rik == 0.0) {
                    continue; // contributes 0; avoids 0 * log(0) = NaN when pi_k == 0
                }
                double pi_k = this.topicInitialProbs.get(k);
                double sum_nl = 0.0;
                for (int r = 0; r < this.numRatingLevels; ++r) {
                    double nir = this.itemNumEachRating.get(i, r);
                    if (nir == 0.0) {
                        continue; // 0 * log(p) is 0 by convention, even when p == 0
                    }
                    sum_nl += nir * Math.log(this.topicRatingProbs.get(k, r));
                }
                logLikelihood += rik * (Math.log(pi_k) + sum_nl);
            }
        }
        // EM increases the likelihood, so treat its negation as a loss to be minimized.
        double loss = -logLikelihood;
        double deltaLoss = loss - this.lastLoss;
        if (iter > 1 && (deltaLoss >= 0.0 || Double.isNaN(deltaLoss))) {
            return true;
        }
        this.lastLoss = loss;
        return false;
    }

    /**
     * Predicts the expected rating of an item under its topic mixture.
     *
     * @param userIdx ignored; the model is item-only
     * @param itemIdx item index
     * @return sum over topics k of {@code z_ik * sum_r ratingScale[r] * p_kr}
     * @throws LibrecException declared by the parent contract
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        double prediction = 0.0;
        for (int k = 0; k < this.numTopics; ++k) {
            double expectedRating = 0.0;
            for (int r = 0; r < this.numRatingLevels; ++r) {
                double ratingValue = (Double) ratingScale.get(r);
                expectedRating += ratingValue * this.topicRatingProbs.get(k, r);
            }
            prediction += this.itemTopicProbs.get(itemIdx, k) * expectedRating;
        }
        return prediction;
    }
}

