/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Gamma;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.MatrixProbabilisticGraphicalRecommender;

public class BUCMRecommender
extends MatrixProbabilisticGraphicalRecommender {
    private int[][][] topicItemRatingNum;
    private DenseMatrix userTopicNum;
    private DenseVector userNum;
    private DenseMatrix topicItemNum;
    private DenseVector topicNum;
    private double[][][] topicItemRatingSumProbs;
    private double[][][] topicItemRatingProbs;
    private DenseMatrix userTopicProbs;
    private DenseMatrix userTopicSumProbs;
    private DenseMatrix topicItemProbs;
    private DenseMatrix topicItemSumProbs;
    private DenseVector alpha;
    private DenseVector beta;
    private DenseVector gamma;
    protected Table<Integer, Integer, Integer> topics;
    protected int numTopics;
    protected int numRatingLevels;

    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.numTopics = this.conf.getInt("rec.pgm.topic.number", 10);
        this.numRatingLevels = ratingScale.size();
        this.userTopicSumProbs = new DenseMatrix(this.numUsers, this.numTopics);
        this.topicItemSumProbs = new DenseMatrix(this.numTopics, this.numItems);
        this.topicItemRatingSumProbs = new double[this.numTopics][this.numItems][this.numRatingLevels];
        this.userTopicNum = new DenseMatrix(this.numUsers, this.numTopics);
        this.userNum = new VectorBasedDenseVector(this.numUsers);
        this.topicItemNum = new DenseMatrix(this.numTopics, this.numItems);
        this.topicNum = new VectorBasedDenseVector(this.numTopics);
        this.topicItemRatingNum = new int[this.numTopics][this.numItems][this.numRatingLevels];
        double initAlpha = this.conf.getDouble("rec.bucm.alpha", 1.0 / (double)this.numTopics);
        this.alpha = new VectorBasedDenseVector(this.numTopics);
        this.alpha.assign((index, value) -> initAlpha);
        double initBeta = this.conf.getDouble("re.bucm.beta", 1.0 / (double)this.numItems);
        this.beta = new VectorBasedDenseVector(this.numItems);
        this.beta.assign((index, value) -> initBeta);
        double initGamma = this.conf.getDouble("rec.bucm.gamma", 1.0 / (double)this.numTopics);
        this.gamma = new VectorBasedDenseVector(this.numRatingLevels);
        this.gamma.assign((index, value) -> initGamma);
        this.topics = HashBasedTable.create();
        for (MatrixEntry me : this.trainMatrix) {
            int u = me.row();
            int i = me.column();
            double rating = me.get();
            int r = ratingScale.indexOf(rating);
            int t = (int)(Randoms.uniform() * (double)this.numTopics);
            this.topics.put(u, i, t);
            this.userTopicNum.plus(u, t, 1.0);
            this.userNum.plus(u, 1.0);
            this.topicItemNum.plus(t, i, 1.0);
            this.topicNum.plus(t, 1.0);
            int[] nArray = this.topicItemRatingNum[t][i];
            int n = r;
            nArray[n] = nArray[n] + 1;
        }
    }

    @Override
    protected void eStep() {
        double sumAlpha = this.alpha.sum();
        double sumBeta = this.beta.sum();
        double sumGamma = this.gamma.sum();
        for (MatrixEntry me : this.trainMatrix) {
            int k;
            int u = me.row();
            int i = me.column();
            double rating = me.get();
            int r = ratingScale.indexOf(rating);
            int t = this.topics.get(u, i);
            this.userTopicNum.plus(u, t, -1.0);
            this.userNum.plus(u, -1.0);
            this.topicItemNum.plus(t, i, -1.0);
            this.topicNum.plus(t, -1.0);
            int[] nArray = this.topicItemRatingNum[t][i];
            int n = r;
            nArray[n] = nArray[n] - 1;
            double[] p = new double[this.numTopics];
            for (k = 0; k < this.numTopics; ++k) {
                double v1 = (this.userTopicNum.get(u, k) + this.alpha.get(k)) / (this.userNum.get(u) + sumAlpha);
                double v2 = (this.topicItemNum.get(k, i) + this.beta.get(i)) / (this.topicNum.get(k) + sumBeta);
                double v3 = ((double)this.topicItemRatingNum[k][i][r] + this.gamma.get(r)) / (this.topicItemNum.get(k, i) + sumGamma);
                p[k] = v1 * v2 * v3;
            }
            for (k = 1; k < this.numTopics; ++k) {
                int n2 = k;
                p[n2] = p[n2] + p[k - 1];
            }
            double rand = Randoms.uniform() * p[this.numTopics - 1];
            for (t = 0; t < this.numTopics && !(rand < p[t]); ++t) {
            }
            this.topics.put(u, i, t);
            this.userTopicNum.plus(u, t, 1.0);
            this.userNum.plus(u, 1.0);
            this.topicItemNum.plus(t, i, 1.0);
            this.topicNum.plus(t, 1.0);
            int[] nArray2 = this.topicItemRatingNum[t][i];
            int n3 = r;
            nArray2[n3] = nArray2[n3] + 1;
        }
    }

    @Override
    protected void mStep() {
        double denominator;
        double numerator;
        double sumAlpha = this.alpha.sum();
        double sumBeta = this.beta.sum();
        double sumGamma = this.gamma.sum();
        for (int k = 0; k < this.numTopics; ++k) {
            double ak = this.alpha.get(k);
            numerator = 0.0;
            denominator = 0.0;
            for (int u = 0; u < this.numUsers; ++u) {
                numerator += Gamma.digamma(this.userTopicNum.get(u, k) + ak) - Gamma.digamma(ak);
                denominator += Gamma.digamma(this.userNum.get(u) + sumAlpha) - Gamma.digamma(sumAlpha);
            }
            if (numerator == 0.0) continue;
            this.alpha.set(k, ak * (numerator / denominator));
        }
        for (int i = 0; i < this.numItems; ++i) {
            double bi = this.beta.get(i);
            numerator = 0.0;
            denominator = 0.0;
            for (int k = 0; k < this.numTopics; ++k) {
                numerator += Gamma.digamma(this.topicItemNum.get(k, i) + bi) - Gamma.digamma(bi);
                denominator += Gamma.digamma(this.topicNum.get(k) + sumBeta) - Gamma.digamma(sumBeta);
            }
            if (numerator == 0.0) continue;
            this.beta.set(i, bi * (numerator / denominator));
        }
        for (int r = 0; r < this.numRatingLevels; ++r) {
            double gr = this.gamma.get(r);
            numerator = 0.0;
            denominator = 0.0;
            for (int i = 0; i < this.numItems; ++i) {
                for (int k = 0; k < this.numTopics; ++k) {
                    numerator += Gamma.digamma((double)this.topicItemRatingNum[k][i][r] + gr) - Gamma.digamma(gr);
                    denominator += Gamma.digamma(this.topicItemNum.get(k, i) + sumGamma) - Gamma.digamma(sumGamma);
                }
            }
            if (numerator == 0.0) continue;
            this.gamma.set(r, gr * (numerator / denominator));
        }
    }

    @Override
    protected boolean isConverged(int iter) {
        double loss = 0.0;
        this.estimateParams();
        int count = 0;
        for (MatrixEntry me : this.trainMatrix) {
            int u = me.row();
            int i = me.column();
            double rui = me.get();
            int r = ratingScale.indexOf(rui);
            double prob = 0.0;
            for (int k = 0; k < this.numTopics; ++k) {
                prob += this.userTopicProbs.get(u, k) * this.topicItemProbs.get(k, i) * this.topicItemRatingProbs[k][i][r];
            }
            loss += -Math.log(prob);
            ++count;
        }
        double delta = (loss /= (double)count) - this.lastLoss;
        if (this.numStats > 1 && delta > 0.0) {
            return true;
        }
        this.lastLoss = loss;
        return false;
    }

    @Override
    protected void readoutParams() {
        int i;
        int k;
        double val;
        double sumAlpha = this.alpha.sum();
        double sumBeta = this.beta.sum();
        double sumGamma = this.gamma.sum();
        for (int u = 0; u < this.numUsers; ++u) {
            for (int k2 = 0; k2 < this.numTopics; ++k2) {
                val = (this.userTopicNum.get(u, k2) + this.alpha.get(k2)) / (this.userNum.get(u) + sumAlpha);
                this.userTopicSumProbs.plus(u, k2, val);
            }
        }
        for (k = 0; k < this.numTopics; ++k) {
            for (i = 0; i < this.numItems; ++i) {
                val = (this.topicItemNum.get(k, i) + this.beta.get(i)) / (this.topicNum.get(k) + sumBeta);
                this.topicItemSumProbs.plus(k, i, val);
            }
        }
        for (k = 0; k < this.numTopics; ++k) {
            for (i = 0; i < this.numItems; ++i) {
                int r = 0;
                while (r < this.numRatingLevels) {
                    val = ((double)this.topicItemRatingNum[k][i][r] + this.gamma.get(r)) / (this.topicItemNum.get(k, i) + sumGamma);
                    double[] dArray = this.topicItemRatingSumProbs[k][i];
                    int n = r++;
                    dArray[n] = dArray[n] + val;
                }
            }
        }
        ++this.numStats;
    }

    @Override
    protected void estimateParams() {
        this.userTopicSumProbs.assign((row, column, value) -> value * (1.0 / (double)this.numStats));
        this.topicItemSumProbs.assign((row, column, value) -> value * (1.0 / (double)this.numStats));
        this.userTopicProbs = this.userTopicSumProbs.clone();
        this.topicItemProbs = this.topicItemSumProbs.clone();
        this.topicItemRatingProbs = new double[this.numTopics][this.numItems][this.numRatingLevels];
        for (int k = 0; k < this.numTopics; ++k) {
            for (int i = 0; i < this.numItems; ++i) {
                for (int r = 0; r < this.numRatingLevels; ++r) {
                    this.topicItemRatingProbs[k][i][r] = this.topicItemRatingSumProbs[k][i][r] / (double)this.numStats;
                }
            }
        }
    }

    protected double perplexity(int u, int j, double ruj) throws Exception {
        int r = (int)(ruj / this.minRate) - 1;
        double prob = 0.0;
        for (int k = 0; k < this.numTopics; ++k) {
            prob += this.userTopicProbs.get(u, k) * this.topicItemProbs.get(k, j) * this.topicItemRatingProbs[k][j][r];
        }
        return -Math.log(prob);
    }

    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        if (this.isRanking) {
            return this.predictRanking(userIdx, itemIdx);
        }
        return this.predictRating(userIdx, itemIdx);
    }

    protected double predictRating(int userIdx, int itemIdx) {
        double pred = 0.0;
        double probs = 0.0;
        for (int r = 0; r < this.numRatingLevels; ++r) {
            double rate = (Double)ratingScale.get(r);
            double prob = 0.0;
            for (int k = 0; k < this.numTopics; ++k) {
                prob += this.userTopicProbs.get(userIdx, k) * this.topicItemProbs.get(k, itemIdx) * this.topicItemRatingProbs[k][itemIdx][r];
            }
            pred += prob * rate;
            probs += prob;
        }
        return pred / probs;
    }

    protected double predictRanking(int userIdx, int itemIdx) {
        double rankScore = 0.0;
        for (int topicIdx = 0; topicIdx < this.numTopics; ++topicIdx) {
            double sum2 = 0.0;
            for (int rateIdx = 0; rateIdx < this.numRatingLevels; ++rateIdx) {
                double rate = (Double)ratingScale.get(rateIdx);
                if (!(rate > this.globalMean)) continue;
                sum2 += this.topicItemRatingProbs[topicIdx][itemIdx][rateIdx];
            }
            rankScore += this.userTopicProbs.get(userIdx, topicIdx) * this.topicItemProbs.get(topicIdx, itemIdx) * sum2;
        }
        return rankScore;
    }
}

