/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.MatrixProbabilisticGraphicalRecommender;

/**
 * BHFree recommender: a two-level latent topic model trained by Gibbs sampling.
 *
 * <p>Every training rating (u, i, r) carries a user topic {@code k} and an item topic {@code l}.
 * The model is described by four smoothed count-based distributions:
 * <ul>
 *   <li>P(k | u) &mdash; smoothed by {@code initAlpha},</li>
 *   <li>P(l | k) &mdash; smoothed by {@code initBeta},</li>
 *   <li>P(r | k, l) &mdash; smoothed by {@code initGamma},</li>
 *   <li>P(i | k, l) &mdash; smoothed by {@code initSigma}.</li>
 * </ul>
 *
 * <p>Configuration keys:
 * <ul>
 *   <li>{@code rec.bhfree.user.topic.number} (default 10),</li>
 *   <li>{@code rec.bhfree.item.topic.number} (default 10),</li>
 *   <li>{@code rec.bhfree.alpha}, {@code rec.bhfree.beta}, {@code rec.bhfree.gamma},</li>
 *   <li>{@code rec.sigma}.</li>
 * </ul>
 */
public class BHFreeRecommender
extends MatrixProbabilisticGraphicalRecommender {
    /** Smoothing pseudo-count for the rating distribution P(r | k, l). */
    private float initGamma;
    /** Smoothing pseudo-count for the item distribution P(i | k, l). */
    private float initSigma;
    /** Smoothing pseudo-count for the user-topic distribution P(k | u). */
    private float initAlpha;
    /** Smoothing pseudo-count for the item-topic distribution P(l | k). */
    private float initBeta;
    /** Number of user topics. */
    private int numUserTopics;
    /** Number of item topics. */
    private int numItemTopics;
    /** Number of training ratings of user u currently assigned to user topic k. */
    private DenseMatrix userTopicNum;
    /** Number of training ratings of user u. */
    private DenseVector userNum;
    /** Number of training ratings currently assigned to user topic k. */
    private DenseVector uTopicNum;
    /** Number of training ratings currently assigned to the topic pair (k, l). */
    private DenseMatrix userTopicItemTopicNum;
    /** Number of ratings with rating-level index r currently assigned to (k, l). */
    private int[][][] userTopicItemTopicRatingNum;
    /** Number of ratings on item i currently assigned to (k, l). */
    private int[][][] userTopicItemTopicItemNum;
    /** Current user-topic assignment of each training entry, keyed by (user, item). */
    private Table<Integer, Integer, Integer> userTopics;
    /** Current item-topic assignment of each training entry, keyed by (user, item). */
    private Table<Integer, Integer, Integer> itemTopics;
    /** Number of distinct rating levels in {@code ratingScale}. */
    private int numRatingLevels;
    /** Estimated P(k | u), averaged over all read-out samples. */
    private DenseMatrix userTopicProbs;
    /** Estimated P(l | k), averaged over all read-out samples. */
    private DenseMatrix userTopicItemTopicProbs;
    /** Running sum of P(k | u) over read-out samples. */
    private DenseMatrix userTopicSumProbs;
    /** Running sum of P(l | k) over read-out samples. */
    private DenseMatrix userTopicItemTopicSumProbs;
    /** Estimated P(r | k, l), averaged over all read-out samples. */
    private double[][][] userTopicItemTopicRatingProbs;
    /** Estimated P(i | k, l), averaged over all read-out samples. */
    private double[][][] userTopicItemTopicItemProbs;
    /** Running sum of P(r | k, l) over read-out samples. */
    private double[][][] userTopicItemTopicRatingSumProbs;
    /** Running sum of P(i | k, l) over read-out samples. */
    private double[][][] userTopicItemTopicItemSumProbs;

    /**
     * Reads hyper-parameters and allocates the count and probability structures.
     * Then assigns every training rating a uniformly random (user topic, item topic) pair.
     *
     * @throws LibrecException if the superclass setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.numUserTopics = this.conf.getInt("rec.bhfree.user.topic.number", 10);
        this.numItemTopics = this.conf.getInt("rec.bhfree.item.topic.number", 10);
        // Must be set before initGamma, whose default is 1 / numRatingLevels
        // (previously read while still 0, yielding an infinite default).
        this.numRatingLevels = ratingScale.size();
        this.initAlpha = this.conf.getFloat("rec.bhfree.alpha", Float.valueOf(1.0f / (float) this.numUserTopics)).floatValue();
        this.initBeta = this.conf.getFloat("rec.bhfree.beta", Float.valueOf(1.0f / (float) this.numItemTopics)).floatValue();
        this.initGamma = this.conf.getFloat("rec.bhfree.gamma", Float.valueOf(1.0f / (float) this.numRatingLevels)).floatValue();
        this.initSigma = this.conf.getFloat("rec.sigma", Float.valueOf(1.0f / (float) this.numItems)).floatValue();

        this.userTopicNum = new DenseMatrix(this.numUsers, this.numUserTopics);
        this.userNum = new VectorBasedDenseVector(this.numUsers);
        this.userTopicItemTopicNum = new DenseMatrix(this.numUserTopics, this.numItemTopics);
        this.uTopicNum = new VectorBasedDenseVector(this.numUserTopics);
        this.userTopicItemTopicRatingNum = new int[this.numUserTopics][this.numItemTopics][this.numRatingLevels];
        this.userTopicItemTopicItemNum = new int[this.numUserTopics][this.numItemTopics][this.numItems];
        this.userTopics = HashBasedTable.create();
        this.itemTopics = HashBasedTable.create();

        for (MatrixEntry me : this.trainMatrix) {
            int u = me.row();
            int i = me.column();
            double rate = me.get();
            int r = ratingScale.indexOf(rate);
            int k = (int) ((double) this.numUserTopics * Randoms.uniform());
            int l = (int) ((double) this.numItemTopics * Randoms.uniform());
            this.updateCounts(u, i, r, k, l, 1);
            this.userTopics.put(u, i, k);
            this.itemTopics.put(u, i, l);
        }

        this.userTopicSumProbs = new DenseMatrix(this.numUsers, this.numUserTopics);
        this.userTopicItemTopicSumProbs = new DenseMatrix(this.numUserTopics, this.numItemTopics);
        this.userTopicItemTopicRatingSumProbs = new double[this.numUserTopics][this.numItemTopics][this.numRatingLevels];
        this.userTopicItemTopicRatingProbs = new double[this.numUserTopics][this.numItemTopics][this.numRatingLevels];
        this.userTopicItemTopicItemSumProbs = new double[this.numUserTopics][this.numItemTopics][this.numItems];
        this.userTopicItemTopicItemProbs = new double[this.numUserTopics][this.numItemTopics][this.numItems];
    }

    /**
     * Adds {@code delta} to every count statistic touched by one training rating.
     *
     * @param u     user index
     * @param i     item index
     * @param r     rating-level index into {@code ratingScale}
     * @param k     user topic
     * @param l     item topic
     * @param delta +1 to add the assignment, -1 to remove it
     */
    private void updateCounts(int u, int i, int r, int k, int l, int delta) {
        this.userTopicNum.plus(u, k, delta);
        this.userNum.plus(u, delta);
        this.userTopicItemTopicNum.plus(k, l, delta);
        this.uTopicNum.plus(k, delta);
        this.userTopicItemTopicRatingNum[k][l][r] += delta;
        this.userTopicItemTopicItemNum[k][l][i] += delta;
    }

    /**
     * Performs one Gibbs sweep over the training ratings.
     *
     * <p>For each rating, the current assignment is removed from the counts. A new
     * (user topic, item topic) pair is then drawn jointly from the conditional
     * P(k | u) * P(l | k) * P(r | k, l) * P(i | k, l), evaluated at every candidate pair.
     * The counts are restored with the new pair.
     */
    @Override
    protected void eStep() {
        final int numPairs = this.numUserTopics * this.numItemTopics;
        // Unnormalized cumulative weights over pairs, flattened as z * numItemTopics + w.
        final double[] cumulative = new double[numPairs];
        for (MatrixEntry me : this.trainMatrix) {
            int u = me.row();
            int i = me.column();
            double rate = me.get();
            int r = ratingScale.indexOf(rate);
            int k = this.userTopics.get(u, i);
            int l = this.itemTopics.get(u, i);
            // Exclude the current assignment so the conditional uses all other ratings.
            this.updateCounts(u, i, r, k, l, -1);

            double userDenom = this.userNum.get(u) + this.numUserTopics * (double) this.initAlpha;
            double total = 0.0;
            for (int z = 0; z < this.numUserTopics; ++z) {
                // Evaluate each factor at the candidate pair (z, w), not at the old (k, l).
                double pUserTopic = (this.userTopicNum.get(u, z) + this.initAlpha) / userDenom;
                double topicDenom = this.uTopicNum.get(z) + this.numItemTopics * (double) this.initBeta;
                for (int w = 0; w < this.numItemTopics; ++w) {
                    double pairNum = this.userTopicItemTopicNum.get(z, w);
                    double pItemTopic = (pairNum + this.initBeta) / topicDenom;
                    double pRating = (this.userTopicItemTopicRatingNum[z][w][r] + (double) this.initGamma)
                            / (pairNum + this.numRatingLevels * (double) this.initGamma);
                    double pItem = (this.userTopicItemTopicItemNum[z][w][i] + (double) this.initSigma)
                            / (pairNum + this.numItems * (double) this.initSigma);
                    total += pUserTopic * pItemTopic * pRating * pItem;
                    cumulative[z * this.numItemTopics + w] = total;
                }
            }

            // Draw the pair jointly; the conditional does not factor into row/column marginals.
            double target = Randoms.uniform() * total;
            int pair = 0;
            while (pair < numPairs - 1 && !(target < cumulative[pair])) {
                ++pair;
            }
            k = pair / this.numItemTopics;
            l = pair % this.numItemTopics;

            this.updateCounts(u, i, r, k, l, 1);
            this.userTopics.put(u, i, k);
            this.itemTopics.put(u, i, l);
        }
    }

    /**
     * Intentionally empty.
     *
     * <p>The counts are updated by sampling in {@link #eStep()}. Parameters are read out in
     * {@link #readoutParams()} and {@link #estimateParams()}.
     */
    @Override
    protected void mStep() {
    }

    /**
     * Adds the current sample's smoothed distributions to the running sums.
     * Also increments {@code numStats}.
     */
    @Override
    protected void readoutParams() {
        for (int u = 0; u < this.numUsers; ++u) {
            double userDenom = this.userNum.get(u) + this.numUserTopics * (double) this.initAlpha;
            for (int k = 0; k < this.numUserTopics; ++k) {
                this.userTopicSumProbs.plus(u, k, (this.userTopicNum.get(u, k) + this.initAlpha) / userDenom);
            }
        }
        for (int k = 0; k < this.numUserTopics; ++k) {
            double topicDenom = this.uTopicNum.get(k) + this.numItemTopics * (double) this.initBeta;
            for (int l = 0; l < this.numItemTopics; ++l) {
                double pairNum = this.userTopicItemTopicNum.get(k, l);
                this.userTopicItemTopicSumProbs.plus(k, l, (pairNum + this.initBeta) / topicDenom);

                double ratingDenom = pairNum + this.numRatingLevels * (double) this.initGamma;
                double[] ratingSums = this.userTopicItemTopicRatingSumProbs[k][l];
                int[] ratingCounts = this.userTopicItemTopicRatingNum[k][l];
                for (int r = 0; r < this.numRatingLevels; ++r) {
                    ratingSums[r] += (ratingCounts[r] + (double) this.initGamma) / ratingDenom;
                }

                double itemDenom = pairNum + this.numItems * (double) this.initSigma;
                double[] itemSums = this.userTopicItemTopicItemSumProbs[k][l];
                int[] itemCounts = this.userTopicItemTopicItemNum[k][l];
                for (int i = 0; i < this.numItems; ++i) {
                    itemSums[i] += (itemCounts[i] + (double) this.initSigma) / itemDenom;
                }
            }
        }
        ++this.numStats;
    }

    /**
     * Sets the estimated distributions to the running sums divided by {@code numStats}.
     *
     * <p>The running sums themselves are left untouched. This method may therefore be
     * called repeatedly while sampling continues.
     */
    @Override
    protected void estimateParams() {
        double scale = 1.0 / (double) this.numStats;
        // Scale copies, not the running sums: scaling the sums in place would corrupt
        // later accumulation if this method runs more than once.
        this.userTopicProbs = this.userTopicSumProbs.clone();
        this.userTopicProbs.assign((row, column, value) -> value * scale);
        this.userTopicItemTopicProbs = this.userTopicItemTopicSumProbs.clone();
        this.userTopicItemTopicProbs.assign((row, column, value) -> value * scale);
        for (int k = 0; k < this.numUserTopics; ++k) {
            for (int l = 0; l < this.numItemTopics; ++l) {
                for (int r = 0; r < this.numRatingLevels; ++r) {
                    this.userTopicItemTopicRatingProbs[k][l][r] = this.userTopicItemTopicRatingSumProbs[k][l][r] * scale;
                }
                for (int i = 0; i < this.numItems; ++i) {
                    this.userTopicItemTopicItemProbs[k][l][i] = this.userTopicItemTopicItemSumProbs[k][l][i] * scale;
                }
            }
        }
    }

    /**
     * Dispatches to {@link #predictRanking} or {@link #predictRating}, depending on {@code isRanking}.
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        if (this.isRanking) {
            return this.predictRanking(userIdx, itemIdx);
        }
        return this.predictRating(userIdx, itemIdx);
    }

    /**
     * Computes the model weight of (item, rating level) for a user.
     *
     * <p>The weight is the sum over k, l of
     * P(k | u) * P(l | k) * P(i | k, l) * P(r | k, l).
     *
     * @param userIdx user index
     * @param itemIdx item index
     * @param r       rating-level index into {@code ratingScale}
     * @return the unnormalized joint weight
     */
    private double ratingItemWeight(int userIdx, int itemIdx, int r) {
        double prob = 0.0;
        for (int k = 0; k < this.numUserTopics; ++k) {
            double pUserTopic = this.userTopicProbs.get(userIdx, k);
            for (int l = 0; l < this.numItemTopics; ++l) {
                prob += pUserTopic * this.userTopicItemTopicProbs.get(k, l)
                        * this.userTopicItemTopicItemProbs[k][l][itemIdx]
                        * this.userTopicItemTopicRatingProbs[k][l][r];
            }
        }
        return prob;
    }

    /**
     * Predicts the expected rating of an item for a user.
     *
     * <p>The rating levels are weighted by their model weight for (user, item). The item
     * term P(i | k, l) is included so that predictions depend on the item.
     *
     * @param userIdx user index
     * @param itemIdx item index
     * @return the expected rating
     */
    protected double predictRating(int userIdx, int itemIdx) {
        double weightedSum = 0.0;
        double norm = 0.0;
        for (int r = 0; r < this.numRatingLevels; ++r) {
            double rate = (Double) ratingScale.get(r);
            double prob = this.ratingItemWeight(userIdx, itemIdx, r);
            weightedSum += rate * prob;
            norm += prob;
        }
        return weightedSum / norm;
    }

    /**
     * Computes a ranking score for an item.
     *
     * <p>The score is the sum, over rating levels, of each level's value times its model
     * weight for (user, item). Unlike {@link #predictRating}, it is not normalized.
     *
     * @param userIdx user index
     * @param itemIdx item index
     * @return the ranking score
     */
    protected double predictRanking(int userIdx, int itemIdx) {
        double rank = 0.0;
        for (int r = 0; r < this.numRatingLevels; ++r) {
            double rate = (Double) ratingScale.get(r);
            rank += rate * this.ratingItemWeight(userIdx, itemIdx, r);
        }
        return rank;
    }
}

