/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Gamma;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.MatrixProbabilisticGraphicalRecommender;

/**
 * Latent Dirichlet Allocation (LDA) recommender trained by Gibbs sampling over topic assignments.
 *
 * <p>Users play the role of documents and items the role of words. Every training entry
 * contributes {@code (int) value} tokens for its (user, item) pair, and each token carries a
 * topic assignment:
 * <ul>
 *   <li>{@link #eStep()} resamples the assignments.</li>
 *   <li>{@link #mStep()} re-estimates the Dirichlet hyper-parameters {@link #alpha} (user-topic)
 *       and {@link #beta} (topic-item).</li>
 *   <li>{@link #readoutParams()} accumulates the current point estimates.</li>
 *   <li>{@link #predict(int, int)} scores a pair as the dot product of the averaged user-topic
 *       and topic-item distributions.</li>
 * </ul>
 */
@ModelData(value={"isRanking", "lda", "userTopicProbs", "topicItemProbs", "trainMatrix"})
public class LDARecommender
extends MatrixProbabilisticGraphicalRecommender {
    /** Initial value of every component of {@link #alpha} ({@code rec.user.dirichlet.prior}). */
    protected float initAlpha;
    /** Initial value of every component of {@link #beta} ({@code rec.topic.dirichlet.prior}). */
    protected float initBeta;
    /** Number of tokens of each item currently assigned to each topic (topics x items). */
    protected DenseMatrix topicItemNumbers;
    /** Number of tokens of each user currently assigned to each topic (users x topics). */
    protected DenseMatrix userTopicNumbers;
    /** Topic of every token, in the order tokens are produced by iterating {@code trainMatrix}. */
    protected int[] topicAssignments;
    /** Total number of tokens per user. */
    protected VectorBasedDenseVector userTokenNumbers;
    /** Total number of tokens per topic. */
    protected VectorBasedDenseVector topicTokenNumbers;
    /** Number of latent topics ({@code rec.topic.number}, default 10). */
    protected int numTopics;
    /** Per-topic Dirichlet prior of the user-topic distributions. */
    protected VectorBasedDenseVector alpha;
    /** Per-item Dirichlet prior of the topic-item distributions. */
    protected VectorBasedDenseVector beta;
    /** Sum of the user-topic estimates over all readouts. */
    protected DenseMatrix userTopicProbsSum;
    /** Sum of the topic-item estimates over all readouts. */
    protected DenseMatrix topicItemProbsSum;
    /** Averaged user-topic distributions used for prediction. */
    protected DenseMatrix userTopicProbs;
    /** Averaged topic-item distributions used for prediction. */
    protected DenseMatrix topicItemProbs;
    /** Number of readouts accumulated into the {@code *ProbsSum} matrices. */
    protected int numStats = 0;

    /**
     * Reads the configuration, allocates the count structures and assigns every token a
     * uniformly random initial topic.
     *
     * @throws LibrecException if the superclass setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.numTopics = this.conf.getInt("rec.topic.number", 10);
        this.userTopicProbsSum = new DenseMatrix(this.numUsers, this.numTopics);
        this.topicItemProbsSum = new DenseMatrix(this.numTopics, this.numItems);
        this.userTopicNumbers = new DenseMatrix(this.numUsers, this.numTopics);
        this.userTokenNumbers = new VectorBasedDenseVector(this.numUsers);
        this.topicItemNumbers = new DenseMatrix(this.numTopics, this.numItems);
        this.topicTokenNumbers = new VectorBasedDenseVector(this.numTopics);
        this.initAlpha = this.conf.getFloat("rec.user.dirichlet.prior", Float.valueOf(50.0f / (float)this.numTopics)).floatValue();
        this.initBeta = this.conf.getFloat("rec.topic.dirichlet.prior", Float.valueOf(0.01f)).floatValue();
        this.alpha = new VectorBasedDenseVector(this.numTopics);
        this.alpha.assign((index, value) -> this.initAlpha);
        this.beta = new VectorBasedDenseVector(this.numItems);
        this.beta.assign((index, value) -> this.initBeta);

        // One slot per token, not per entry: an entry with value v yields (int) v tokens,
        // so trainMatrix.size() under-allocates whenever some v > 1.
        this.topicAssignments = new int[this.countTokens()];
        int tokenIdx = 0;
        for (MatrixEntry matrixEntry : this.trainMatrix) {
            int userIdx = matrixEntry.row();
            int itemIdx = matrixEntry.column();
            int num = (int) matrixEntry.get();
            for (int numIdx = 0; numIdx < num; ++numIdx) {
                int topicIdx = Randoms.uniform(this.numTopics);
                this.topicAssignments[tokenIdx++] = topicIdx;
                this.userTopicNumbers.plus(userIdx, topicIdx, 1.0);
                this.userTokenNumbers.plus(userIdx, 1.0);
                this.topicItemNumbers.plus(topicIdx, itemIdx, 1.0);
                this.topicTokenNumbers.plus(topicIdx, 1.0);
            }
        }
    }

    /**
     * Returns the total number of tokens in the training matrix, i.e. the sum of
     * {@code (int) value} over all entries (entries truncating to a non-positive count
     * contribute nothing, matching the token loops).
     *
     * @throws ArithmeticException if the total does not fit in an {@code int}
     */
    private int countTokens() {
        long total = 0L;
        for (MatrixEntry matrixEntry : this.trainMatrix) {
            total += Math.max(0, (int) matrixEntry.get());
        }
        return Math.toIntExact(total);
    }

    /**
     * Resamples the topic of every token from its conditional distribution given all
     * other assignments. The token is first removed from the counts and then re-added
     * under the newly drawn topic.
     */
    @Override
    protected void eStep() {
        double sumAlpha = this.alpha.sum();
        double sumBeta = this.beta.sum();
        // Reused across tokens; holds the unnormalized cumulative conditional.
        double[] cumulative = new double[this.numTopics];
        int tokenIdx = 0;
        for (MatrixEntry matrixEntry : this.trainMatrix) {
            int userIdx = matrixEntry.row();
            int itemIdx = matrixEntry.column();
            int num = (int) matrixEntry.get();
            for (int numIdx = 0; numIdx < num; ++numIdx) {
                int oldTopic = this.topicAssignments[tokenIdx];
                this.userTopicNumbers.plus(userIdx, oldTopic, -1.0);
                this.userTokenNumbers.plus(userIdx, -1.0);
                this.topicItemNumbers.plus(oldTopic, itemIdx, -1.0);
                this.topicTokenNumbers.plus(oldTopic, -1.0);

                double userDenominator = this.userTokenNumbers.get(userIdx) + sumAlpha;
                double running = 0.0;
                for (int k = 0; k < this.numTopics; ++k) {
                    double weight = (this.userTopicNumbers.get(userIdx, k) + this.alpha.get(k)) / userDenominator
                            * (this.topicItemNumbers.get(k, itemIdx) + this.beta.get(itemIdx))
                            / (this.topicTokenNumbers.get(k) + sumBeta);
                    running = k == 0 ? weight : weight + running;
                    cumulative[k] = running;
                }
                int newTopic = sampleFromCumulative(cumulative);

                this.userTopicNumbers.plus(userIdx, newTopic, 1.0);
                this.userTokenNumbers.plus(userIdx, 1.0);
                this.topicItemNumbers.plus(newTopic, itemIdx, 1.0);
                this.topicTokenNumbers.plus(newTopic, 1.0);
                this.topicAssignments[tokenIdx] = newTopic;
                ++tokenIdx;
            }
        }
    }

    /**
     * Draws an index from an unnormalized cumulative distribution by inverse-CDF search.
     * If no bucket is selected (total mass zero or NaN), it returns the last index rather
     * than running past the end of the array.
     *
     * @param cumulative non-empty, non-decreasing running sums of unnormalized weights
     * @return the first index whose cumulative weight exceeds the uniform draw
     */
    private static int sampleFromCumulative(double[] cumulative) {
        int last = cumulative.length - 1;
        double rand = Randoms.uniform() * cumulative[last];
        for (int k = 0; k < last; ++k) {
            if (rand < cumulative[k]) {
                return k;
            }
        }
        return last;
    }

    /**
     * Re-estimates {@link #alpha} and {@link #beta} with a digamma-based fixed-point step
     * (Minka-style): each component is multiplied by
     * {@code sum[psi(count + a_k) - psi(a_k)] / sum[psi(total + sum(a)) - psi(sum(a))]}.
     * The sums of the priors are taken before any component is updated.
     */
    @Override
    protected void mStep() {
        double sumAlpha = this.alpha.sum();
        double sumBeta = this.beta.sum();

        double digammaAlphaSum = Gamma.digamma(sumAlpha);
        double alphaDenominator = 0.0;
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            alphaDenominator += Gamma.digamma(this.userTokenNumbers.get(userIdx) + sumAlpha) - digammaAlphaSum;
        }
        for (int topicIdx = 0; topicIdx < this.numTopics; ++topicIdx) {
            double topicAlpha = this.alpha.get(topicIdx);
            double digammaTopicAlpha = Gamma.digamma(topicAlpha);
            double numerator = 0.0;
            for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
                numerator += Gamma.digamma(this.userTopicNumbers.get(userIdx, topicIdx) + topicAlpha) - digammaTopicAlpha;
            }
            // A zero numerator would collapse the prior to 0; keep the previous value instead.
            if (numerator != 0.0) {
                this.alpha.set(topicIdx, topicAlpha * (numerator / alphaDenominator));
            }
        }

        double digammaBetaSum = Gamma.digamma(sumBeta);
        double betaDenominator = 0.0;
        for (int topicIdx = 0; topicIdx < this.numTopics; ++topicIdx) {
            betaDenominator += Gamma.digamma(this.topicTokenNumbers.get(topicIdx) + sumBeta) - digammaBetaSum;
        }
        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            double itemBeta = this.beta.get(itemIdx);
            double digammaItemBeta = Gamma.digamma(itemBeta);
            double numerator = 0.0;
            for (int topicIdx = 0; topicIdx < this.numTopics; ++topicIdx) {
                numerator += Gamma.digamma(this.topicItemNumbers.get(topicIdx, itemIdx) + itemBeta) - digammaItemBeta;
            }
            // Same guard as for alpha: leave the prior unchanged rather than zero it.
            if (numerator != 0.0) {
                this.beta.set(itemIdx, itemBeta * (numerator / betaDenominator));
            }
        }
    }

    /**
     * Adds the smoothed user-topic and topic-item estimates of the current assignment
     * state to the running sums and increments {@link #numStats}.
     */
    @Override
    protected void readoutParams() {
        double sumAlpha = this.alpha.sum();
        double sumBeta = this.beta.sum();
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            for (int factorIdx = 0; factorIdx < this.numTopics; ++factorIdx) {
                double val = (this.userTopicNumbers.get(userIdx, factorIdx) + this.alpha.get(factorIdx)) / (this.userTokenNumbers.get(userIdx) + sumAlpha);
                this.userTopicProbsSum.plus(userIdx, factorIdx, val);
            }
        }
        for (int factorIdx = 0; factorIdx < this.numTopics; ++factorIdx) {
            for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                double val = (this.topicItemNumbers.get(factorIdx, itemIdx) + this.beta.get(itemIdx)) / (this.topicTokenNumbers.get(factorIdx) + sumBeta);
                this.topicItemProbsSum.plus(factorIdx, itemIdx, val);
            }
        }
        ++this.numStats;
    }

    /**
     * Sets {@link #userTopicProbs} and {@link #topicItemProbs} to the average of all
     * readouts accumulated so far.
     */
    @Override
    protected void estimateParams() {
        if (this.numStats == 0) {
            // Without any readout the average would be 0 * (1 / 0) = NaN everywhere;
            // fall back to the current assignment state as a single sample.
            this.readoutParams();
        }
        double scale = 1.0 / (double) this.numStats;
        this.userTopicProbs = this.userTopicProbsSum.times(scale);
        this.topicItemProbs = this.topicItemProbsSum.times(scale);
    }

    /**
     * Scores a (user, item) pair as the dot product of the user's averaged topic
     * distribution and the item's averaged per-topic probabilities.
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        return this.userTopicProbs.row(userIdx).dot(this.topicItemProbs.column(itemIdx));
    }
}

