/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import net.librec.common.LibrecException;
import net.librec.math.algorithm.Gamma;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.MatrixFactorizationRecommender;
import org.apache.commons.math3.distribution.GammaDistribution;

/**
 * Poisson matrix factorization recommender whose user and item factors are the means
 * (shape / rate) of Gamma variational parameters, refined by fixed-point updates over the
 * training ratings.
 *
 * <p>Hyper-parameters are read from the configuration in {@link #setup()}:
 * <ul>
 *   <li>{@code rec.a} / {@code rec.c} - prior shape added to every user / item factor shape;</li>
 *   <li>{@code rec.a.prime}, {@code rec.b.prime} - parameters of the per-user rate prior;</li>
 *   <li>{@code rec.c.prime}, {@code rec.d.prime} - parameters of the per-item rate prior.</li>
 * </ul>
 *
 * <p>The shape update follows the {@code shape = prior + sum(phi)} form of the hierarchical
 * Poisson factorization updates (Gopalan, Hofman, Blei): the prior is added once per
 * (row, factor), not once per observed rating.
 */
public class BPoissMFRecommender extends MatrixFactorizationRecommender {
    /** Prior shape for user factors ({@code rec.a}, default 0.3). */
    private double a;
    /** Shape of the per-user rate prior ({@code rec.a.prime}, default 0.3). */
    private double aPrime;
    /** Second parameter of the per-user rate prior ({@code rec.b.prime}, default 1.0). */
    private double bPrime;
    /** Prior shape for item factors ({@code rec.c}, default 0.3). */
    private double c;
    /** Shape of the per-item rate prior ({@code rec.c.prime}, default 0.3). */
    private double cPrime;
    /** Second parameter of the per-item rate prior ({@code rec.d.prime}, default 1.0). */
    private double dPrime;

    /**
     * Reads the hyper-parameters from the configuration and draws every user and item
     * factor from a Gamma distribution ({@code Gamma(a, 1/b')} for users,
     * {@code Gamma(c, 1/d')} for items, in shape/scale form).
     *
     * @throws LibrecException if the parent set-up fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.a = this.conf.getDouble("rec.a", 0.3);
        this.aPrime = this.conf.getDouble("rec.a.prime", 0.3);
        this.bPrime = this.conf.getDouble("rec.b.prime", 1.0);
        this.c = this.conf.getDouble("rec.c", 0.3);
        this.cPrime = this.conf.getDouble("rec.c.prime", 0.3);
        this.dPrime = this.conf.getDouble("rec.d.prime", 1.0);

        this.drawGammaFactors(this.userFactors, this.numUsers, new GammaDistribution(this.a, 1.0 / this.bPrime));
        this.drawGammaFactors(this.itemFactors, this.numItems, new GammaDistribution(this.c, 1.0 / this.dPrime));
    }

    /**
     * Runs {@code numIterations} rounds of updates. Each round recomputes {@code phi} for
     * every training rating, then the user rates and shapes, the user factors, the item
     * rates and item factors, and finally the per-user / per-item rate vectors.
     *
     * @throws LibrecException declared by the overridden method; not thrown by this body directly
     */
    @Override
    protected void trainModel() throws LibrecException {
        double kShp = this.aPrime + (double) this.numFactors * this.a;
        double tShp = this.cPrime + (double) this.numFactors * this.c;

        VectorBasedDenseVector kRte = new VectorBasedDenseVector(this.numUsers);
        for (int u = 0; u < this.numUsers; ++u) {
            kRte.set(u, this.userFactors.row(u).sum() + this.bPrime);
        }
        VectorBasedDenseVector tRte = new VectorBasedDenseVector(this.numItems);
        for (int i = 0; i < this.numItems; ++i) {
            tRte.set(i, this.itemFactors.row(i).sum() + this.dPrime);
        }

        // The factor matrices do not change during initialisation, so their column sums are
        // computed once instead of once per (row, factor) pair.
        double[] itemColumnSums = this.factorColumnSums(this.itemFactors);
        double[] userColumnSums = this.factorColumnSums(this.userFactors);

        GammaDistribution gammaPre = new GammaDistribution(this.aPrime, this.bPrime / this.aPrime);
        DenseMatrix gammaRte = new DenseMatrix(this.numUsers, this.numFactors);
        for (int u = 0; u < this.numUsers; ++u) {
            double userSample = gammaPre.sample();
            for (int k = 0; k < this.numFactors; ++k) {
                gammaRte.set(u, k, userSample + itemColumnSums[k]);
            }
        }

        GammaDistribution lambdaPre = new GammaDistribution(this.cPrime, this.dPrime / this.cPrime);
        DenseMatrix lambdaRte = new DenseMatrix(this.numItems, this.numFactors);
        for (int i = 0; i < this.numItems; ++i) {
            double itemSample = lambdaPre.sample();
            for (int k = 0; k < this.numFactors; ++k) {
                lambdaRte.set(i, k, itemSample + userColumnSums[k]);
            }
        }

        // Initial shapes: rate * current factor, jittered by a uniform factor in [0.85, 1.15].
        DenseMatrix gammaShp = new DenseMatrix(this.numUsers, this.numFactors);
        for (int u = 0; u < this.numUsers; ++u) {
            for (int k = 0; k < this.numFactors; ++k) {
                gammaShp.set(u, k, gammaRte.get(u, k) * this.userFactors.get(u, k) * Randoms.uniform(0.85, 1.15));
            }
        }
        DenseMatrix lambdaShp = new DenseMatrix(this.numItems, this.numFactors);
        for (int i = 0; i < this.numItems; ++i) {
            for (int k = 0; k < this.numFactors; ++k) {
                lambdaShp.set(i, k, lambdaRte.get(i, k) * this.itemFactors.get(i, k) * Randoms.uniform(0.85, 1.15));
            }
        }

        // One row of phi per training rating, in trainMatrix iteration order.
        DenseMatrix phi = new DenseMatrix(this.numRates, this.numFactors);
        double addKRate = this.aPrime / this.bPrime;
        double addTRate = this.cPrime / this.dPrime;

        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.updatePhi(gammaShp, gammaRte, lambdaShp, lambdaRte, phi);

            // User rates: per-user term plus the column sum of the current item factors.
            itemColumnSums = this.factorColumnSums(this.itemFactors);
            for (int u = 0; u < this.numUsers; ++u) {
                double userTerm = kShp / kRte.get(u);
                for (int k = 0; k < this.numFactors; ++k) {
                    gammaRte.set(u, k, userTerm + itemColumnSums[k]);
                }
            }

            // Shapes are rebuilt from scratch each round: clear, then accumulate phi + prior.
            this.resetShapes(gammaShp, this.numUsers);
            this.resetShapes(lambdaShp, this.numItems);
            this.update_G_n_L_sh(gammaShp, lambdaShp, phi, this.a, this.c);

            for (int u = 0; u < this.numUsers; ++u) {
                for (int k = 0; k < this.numFactors; ++k) {
                    this.userFactors.set(u, k, gammaShp.get(u, k) / gammaRte.get(u, k));
                }
            }

            // Item rates use the user factors that were just refreshed above.
            userColumnSums = this.factorColumnSums(this.userFactors);
            for (int i = 0; i < this.numItems; ++i) {
                double itemTerm = tShp / tRte.get(i);
                for (int k = 0; k < this.numFactors; ++k) {
                    lambdaRte.set(i, k, itemTerm + userColumnSums[k]);
                }
            }
            for (int i = 0; i < this.numItems; ++i) {
                for (int k = 0; k < this.numFactors; ++k) {
                    this.itemFactors.set(i, k, lambdaShp.get(i, k) / lambdaRte.get(i, k));
                }
            }

            for (int u = 0; u < this.numUsers; ++u) {
                kRte.set(u, addKRate + this.userFactors.row(u).sum());
            }
            for (int i = 0; i < this.numItems; ++i) {
                tRte.set(i, addTRate + this.itemFactors.row(i).sum());
            }
        }
    }

    /**
     * Recomputes {@code phi} for every training rating. For the n-th entry of
     * {@code trainMatrix} (iteration order), row n of {@code phi} receives
     * {@code exp(digamma(gammaShp) - log(gammaRte) + digamma(lambdaShp) - log(lambdaRte))}
     * per factor, normalised over the factors and multiplied by the rating, so each row
     * sums to the rating value.
     *
     * @param gammaShp  user shape parameters (users x factors), read only
     * @param gammaRte  user rate parameters (users x factors), read only
     * @param lambdaShp item shape parameters (items x factors), read only
     * @param lambdaRte item rate parameters (items x factors), read only
     * @param phi       output matrix (ratings x factors), overwritten row by row
     */
    protected void updatePhi(DenseMatrix gammaShp, DenseMatrix gammaRte, DenseMatrix lambdaShp, DenseMatrix lambdaRte, DenseMatrix phi) {
        int ratingCount = 0;
        for (MatrixEntry me : this.trainMatrix) {
            int userIdx = me.row();
            int itemIdx = me.column();
            double rating = me.get();

            double sumphi = 0.0;
            for (int k = 0; k < this.numFactors; ++k) {
                double unnormalised = Math.exp(
                        Gamma.digamma(gammaShp.get(userIdx, k)) - Math.log(gammaRte.get(userIdx, k))
                        + Gamma.digamma(lambdaShp.get(itemIdx, k)) - Math.log(lambdaRte.get(itemIdx, k)));
                phi.set(ratingCount, k, unnormalised);
                sumphi += unnormalised;
            }
            for (int k = 0; k < this.numFactors; ++k) {
                phi.set(ratingCount, k, phi.get(ratingCount, k) * rating / sumphi);
            }
            ++ratingCount;
        }
    }

    /**
     * Accumulates the shape parameters from {@code phi}: every training rating adds its
     * {@code phi} row to the rating's user row of {@code gammaShp} and item row of
     * {@code lambdaShp}; afterwards the prior shape ({@code a} for users, {@code c} for
     * items) is added exactly once to every entry. With matrices cleared by the caller the
     * result is {@code prior + sum(phi)} per (row, factor), including rows that have no
     * training ratings.
     *
     * <p>Rows of {@code phi} are matched to ratings by {@code trainMatrix} iteration order,
     * the same order used by {@link #updatePhi}.
     *
     * @param gammaShp  user shape accumulator (users x factors), updated in place
     * @param lambdaShp item shape accumulator (items x factors), updated in place
     * @param phi       per-rating factor allocations (ratings x factors)
     * @param a         prior shape added once to each user entry
     * @param c         prior shape added once to each item entry
     */
    protected void update_G_n_L_sh(DenseMatrix gammaShp, DenseMatrix lambdaShp, DenseMatrix phi, double a, double c) {
        int ratingCount = 0;
        for (MatrixEntry me : this.trainMatrix) {
            int userIdx = me.row();
            int itemIdx = me.column();
            for (int k = 0; k < this.numFactors; ++k) {
                double phiValue = phi.get(ratingCount, k);
                gammaShp.set(userIdx, k, gammaShp.get(userIdx, k) + phiValue);
                lambdaShp.set(itemIdx, k, lambdaShp.get(itemIdx, k) + phiValue);
            }
            ++ratingCount;
        }
        // The prior belongs to the (row, factor) entry, not to each rating, so add it once here.
        this.addPrior(gammaShp, this.numUsers, a);
        this.addPrior(lambdaShp, this.numItems, c);
    }

    /** Fills the first {@code rows} rows of {@code factors} with samples from {@code distribution}. */
    private void drawGammaFactors(DenseMatrix factors, int rows, GammaDistribution distribution) {
        for (int r = 0; r < rows; ++r) {
            for (int k = 0; k < this.numFactors; ++k) {
                factors.set(r, k, distribution.sample());
            }
        }
    }

    /** Returns the sum of each of the {@code numFactors} columns of {@code factors}. */
    private double[] factorColumnSums(DenseMatrix factors) {
        double[] sums = new double[this.numFactors];
        for (int k = 0; k < this.numFactors; ++k) {
            sums[k] = factors.column(k).sum();
        }
        return sums;
    }

    /** Sets every entry in the first {@code rows} rows of {@code shapes} to zero. */
    private void resetShapes(DenseMatrix shapes, int rows) {
        for (int r = 0; r < rows; ++r) {
            for (int k = 0; k < this.numFactors; ++k) {
                shapes.set(r, k, 0.0);
            }
        }
    }

    /** Adds {@code prior} to every entry in the first {@code rows} rows of {@code shapes}. */
    private void addPrior(DenseMatrix shapes, int rows, double prior) {
        for (int r = 0; r < rows; ++r) {
            for (int k = 0; k < this.numFactors; ++k) {
                shapes.set(r, k, shapes.get(r, k) + prior);
            }
        }
    }
}

