/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import net.librec.common.LibrecException;
import net.librec.math.algorithm.Gamma;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.Vector;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Variational inference for a stick-breaking Poisson factorization recommender (the class name
 * suggests Bayesian nonparametric Poisson factorization, BNPPF).
 *
 * <p>Each user u has stick proportions {@code v[u][k]} and stick weights
 * {@code pi[u][k] = v[u][k] * prod_{j<k} (1 - v[u][j])}. Each user also has a Gamma-distributed
 * scalar {@code s[u]}, and the user factors are {@code pi[u][k] * E[s[u]]}. Item factors
 * {@code beta} are Gamma-distributed, with one rate per factor.
 *
 * <p>The model is truncated at {@code numFactors}. The mass beyond the truncation is handled in
 * closed form by the {@code *_at_truncation} and {@code *_infsum} helpers.
 *
 * <p>Configuration keys: {@code rec.alpha}, {@code rec.c}, {@code rec.a} and {@code rec.b}. All
 * default to 0.3.
 */
public class BNPPFRecommeder extends MatrixFactorizationRecommender {
    /** Stick-breaking concentration; also used as the shape prior of {@code s}. */
    private double alpha;
    /** Rate prior of {@code s}. */
    private double c;
    /** Shape prior of {@code beta}. */
    private double a;
    /** Rate prior of {@code beta}. */
    private double b;
    /** Stick proportions, numUsers x numFactors. */
    private DenseMatrix v;
    /** Stick weights derived from {@code v}, numUsers x numFactors. */
    private DenseMatrix pi;
    /** Logarithm of {@code pi}. */
    private DenseMatrix logpi;
    /** Gamma item factors (per-cell shape, per-column rate). */
    private GammaDenseMatrixGR beta;
    /** Gamma per-user scalars. */
    private GammaDenseVector s;
    /** Per-factor sum over items of E[beta]. */
    private DenseVector eBetaSum;
    /** Per-factor sum over users of E[s[u]] * pi[u][k]. */
    private DenseVector eThetaSum;
    /** Per-iteration expected counts per (user, factor): sum over the user's items of y * phi. */
    private DenseMatrix zUsers;
    /** Per-iteration expected counts per (item, factor): sum over the item's users of y * phi. */
    private DenseMatrix zItems;
    /** Sum of each user's training values. */
    private DenseVector userBudget;
    /** (a / b) * numItems: prior mean of a factor's item-sum of beta beyond the truncation. */
    double d_scalar;

    /**
     * Reads the hyperparameters, initialises the Gamma factors and the sticks, and allocates the
     * per-iteration accumulators.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.alpha = this.conf.getDouble("rec.alpha", 0.3);
        this.c = this.conf.getDouble("rec.c", 0.3);
        this.a = this.conf.getDouble("rec.a", 0.3);
        this.b = this.conf.getDouble("rec.b", 0.3);
        this.beta = new GammaDenseMatrixGR(this.numItems, this.numFactors, this.a, this.b);
        this.beta.init();
        this.s = new GammaDenseVector(this.numUsers, this.alpha, this.c);
        this.s.init();
        this.v = new DenseMatrix(this.numUsers, this.numFactors);
        // NOTE(review): presumably a small random initialisation — confirm DenseMatrix.init(double).
        this.v.init(0.001);
        this.pi = new DenseMatrix(this.numUsers, this.numFactors);
        this.logpi = new DenseMatrix(this.numUsers, this.numFactors);
        for (int u = 0; u < this.numUsers; ++u) {
            // Running sum of log(1 - v[u][j]) for j < k: the log of the stick left before break k.
            double logRemaining = 0.0;
            for (int k = 0; k < this.numFactors; ++k) {
                double vValue = this.v.get(u, k);
                if (vValue < 1.0E-30) {
                    // Keep log(v) finite.
                    vValue = 1.0E-30;
                    this.v.set(u, k, 1.0E-30);
                }
                if (k > 0) {
                    logRemaining += Math.log(1.0 - this.v.get(u, k - 1));
                }
                double logPiValue = Math.log(vValue) + logRemaining;
                this.logpi.set(u, k, logPiValue);
                // The stick weight is exp(log pi). The previous code used exp(v), which is not pi
                // and disagreed with both logpi and the recurrence used in updateSticks.
                double piValue = Math.exp(logPiValue);
                this.pi.set(u, k, piValue);
                this.userFactors.set(u, k, piValue * this.s.value.get(u));
            }
        }
        this.eBetaSum = new VectorBasedDenseVector(this.numFactors);
        this.eThetaSum = new VectorBasedDenseVector(this.numFactors);
        this.zUsers = new DenseMatrix(this.numUsers, this.numFactors);
        this.zItems = new DenseMatrix(this.numItems, this.numFactors);
        this.userBudget = new VectorBasedDenseVector(this.numUsers);
        for (int u = 0; u < this.numUsers; ++u) {
            this.userBudget.set(u, this.trainMatrix.row(u).sum());
        }
        this.d_scalar = this.a / this.b * (double) this.numItems;
    }

    /**
     * Runs {@code numIterations} rounds of variational updates. Each round first accumulates
     * expected counts from every training entry, then updates the sticks, the user scalars and the
     * item factors. Afterwards, {@code itemFactors} is set to E[beta].
     *
     * @throws LibrecException never thrown here; declared by the parent contract
     */
    @Override
    protected void trainModel() throws LibrecException {
        this.computeExpectations();
        this.computeSums();
        VectorBasedDenseVector phi = new VectorBasedDenseVector(this.numFactors);
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.clearState();
            for (int u = 0; u < this.numUsers; ++u) {
                // Depends only on this user's scalar and last stick, which this loop does not change.
                double tailLogMass = this.compute_mult_normalizer_infsum(u);
                SequentialSparseVector items = this.trainMatrix.row(u);
                for (Vector.VectorEntry ve : items) {
                    int itemIdx = ve.index();
                    double y = ve.get();
                    this.getPhi(u, itemIdx, phi, tailLogMass);
                    // Scale explicitly instead of calling phi.times(y): whether times() mutates in
                    // place is not visible here, and a discarded copy would silently drop y.
                    double weight = y > 1.0 ? y : 1.0;
                    for (int k = 0; k < this.numFactors; ++k) {
                        double expectedCount = phi.get(k) * weight;
                        this.zUsers.set(u, k, this.zUsers.get(u, k) + expectedCount);
                        this.zItems.set(itemIdx, k, this.zItems.get(itemIdx, k) + expectedCount);
                    }
                }
            }
            this.updateSticks();
            this.update_sticks_scalar();
            this.updateItems();
        }
        this.itemFactors = new DenseMatrix(this.beta.value);
    }

    /** Recomputes the means and log-expectations of {@code s} and {@code beta}. */
    private void computeExpectations() {
        this.s.computeExpectations();
        this.beta.computeExpectations();
    }

    /** Recomputes both cached per-factor sums ({@code eThetaSum} and {@code eBetaSum}). */
    private void computeSums() {
        this.computeEThetaSum();
        this.computeEBetaSum();
    }

    /** Sets {@code eThetaSum[k]} to the sum over users of E[s[u]] * pi[u][k]. */
    private void computeEThetaSum() {
        for (int k = 0; k < this.numFactors; ++k) {
            this.eThetaSum.set(k, 0.0);
        }
        for (int u = 0; u < this.numUsers; ++u) {
            for (int k = 0; k < this.numFactors; ++k) {
                this.eThetaSum.set(k, this.eThetaSum.get(k) + this.s.value.get(u) * this.pi.get(u, k));
            }
        }
    }

    /** Sets {@code eBetaSum[k]} to the sum over items of E[beta[i][k]]. */
    private void computeEBetaSum() {
        for (int k = 0; k < this.numFactors; ++k) {
            this.eBetaSum.set(k, this.beta.value.column(k).sum());
        }
    }

    /** Zeroes the expected-count accumulators before a new round. */
    private void clearState() {
        for (int i = 0; i < this.numItems; ++i) {
            for (int k = 0; k < this.numFactors; ++k) {
                this.zItems.set(i, k, 0.0);
            }
        }
        for (int u = 0; u < this.numUsers; ++u) {
            for (int k = 0; k < this.numFactors; ++k) {
                this.zUsers.set(u, k, 0.0);
            }
        }
    }

    /**
     * Fills {@code phi} with the factor-assignment probabilities for (user, item).
     *
     * <p>The normalizer includes the closed-form mass of the factors beyond the truncation, so the
     * entries sum to less than one.
     *
     * @param userIdx user index
     * @param itemIdx item index
     * @param phi output vector of length numFactors, overwritten
     * @param tailLogMass {@code compute_mult_normalizer_infsum(userIdx)}
     */
    private void getPhi(int userIdx, int itemIdx, DenseVector phi, double tailLogMass) {
        for (int k = 0; k < this.numFactors; ++k) {
            phi.set(k, this.elogTheta(userIdx, k) + this.beta.logValue.get(itemIdx, k));
        }
        double logNormalizer = this.logAddExp(this.logSum(phi), tailLogMass);
        this.lognormalize(phi, logNormalizer);
    }

    /**
     * Log of the unnormalised assignment mass of all factors at and beyond the truncation.
     * It is the log of a geometric series whose first term is exp(x) and whose ratio is
     * exp(E[log(1 - v)]).
     */
    private double compute_mult_normalizer_infsum(int u) {
        double x = this.elogtheta_at_truncation(u) + this.elogbeta_at_truncation();
        // Equals E[log(1 - v)] for v ~ Beta(1, alpha).
        double elogv_t = Gamma.digamma(this.alpha) - Gamma.digamma(1.0 + this.alpha);
        return x - Math.log(1.0 - Math.exp(elogv_t));
    }

    /**
     * Log-scale theta for the first factor past the truncation. It combines log s[u], the log of
     * the stick remaining after the last break, and E[log v] for v ~ Beta(1, alpha).
     */
    private double elogtheta_at_truncation(int u) {
        double elogvt = Gamma.digamma(1.0) - Gamma.digamma(1.0 + this.alpha);
        int last = this.numFactors - 1;
        return this.s.logValue.get(u) + elogvt + this.logpi.get(u, last)
                - Math.log(this.v.get(u, last)) + Math.log(1.0 - this.v.get(u, last));
    }

    /** E[log beta] under the Gamma(a, b) prior (shape/rate), used for untruncated factors. */
    private double elogbeta_at_truncation() {
        return Gamma.digamma(this.a) - Math.log(this.b);
    }

    /** Numerically stable log(sum(exp(x))) over the entries; returns 0.0 for an empty vector. */
    private double logSum(DenseVector vector) {
        double result = 0.0;
        boolean first = true;
        for (Vector.VectorEntry ve : vector) {
            if (first) {
                result = ve.get();
                first = false;
            } else {
                result = this.logAddExp(result, ve.get());
            }
        }
        return result;
    }

    /** Numerically stable log(exp(x) + exp(y)). */
    private double logAddExp(double x, double y) {
        if (y < x) {
            return x + Math.log1p(Math.exp(y - x));
        }
        return y + Math.log1p(Math.exp(x - y));
    }

    /**
     * Turns log-scale weights into probabilities: each entry becomes {@code exp(x - logSum)}.
     * The previous version only subtracted {@code logSum}. It therefore left log-probabilities,
     * which were then accumulated as (negative) expected counts.
     */
    private void lognormalize(DenseVector vector, double logSum) {
        for (Vector.VectorEntry ve : vector) {
            ve.set(Math.exp(ve.get() - logSum));
        }
    }

    /** log s[u] + log pi[u][k]. */
    private double elogTheta(int u, int k) {
        return this.s.logValue.get(u) + this.logpi.get(u, k);
    }

    /**
     * Coordinate-ascent update of every user's stick proportions {@code v[u][k]}, in increasing k.
     * Each update also refreshes the corresponding pi, log pi and user-factor entries.
     *
     * <p>{@code v[u][k]} is taken from a root of {@code Auk * v^2 + B * v - z[u][k] = 0}. When
     * |Auk| < 1e-30 it is {@code z[u][k] / B}, clamped below at 1e-30.
     */
    private void updateSticks() {
        for (int u = 0; u < this.numUsers; ++u) {
            // Reads pi/v only at index numFactors-1, which this loop overwrites only after the
            // coefficient for that index is computed, so it is identical for every k.
            double infsum = this.compute_scalar_rate_infsum(u);
            double scalar = this.s.value.get(u);
            for (int k = 0; k < this.numFactors; ++k) {
                double vValue = this.v.get(u, k);
                double[] pAndSum = this.sum_of_prod_in_range(u, k + 1);
                // NOTE(review): differentiating the stick-breaking objective in v[u][k] suggests
                // Auk = s_u * (-prob_at_k / v + sum / (1 - v) + infsum / (1 - v)). That form has s_u
                // scaling all three terms and no "-1" offset. The expression is kept as-is; confirm
                // it against the BNPPF paper before changing it.
                double aCoef = scalar * (-1.0 + this.prob_at_k(u, k) / vValue)
                        + pAndSum[1] / (1.0 - vValue) + infsum / (1.0 - vValue);
                if (Math.abs(aCoef) < 1.0E-30) {
                    // Degenerate (linear) case: B * v - z = 0.
                    double x = this.zUsers.get(u, k) / this.B(u, k, aCoef);
                    vValue = x < 1.0E-30 ? 1.0E-30 : x;
                } else {
                    vValue = this.solve_quadratic(aCoef, this.B(u, k, aCoef), -this.zUsers.get(u, k));
                }
                this.v.set(u, k, vValue);
                double piValue;
                if (k > 0) {
                    // pi_k = pi_{k-1} / v_{k-1} * (1 - v_{k-1}) * v_k, using the freshly updated k-1.
                    piValue = this.pi.get(u, k - 1) / this.v.get(u, k - 1)
                            * (1.0 - this.v.get(u, k - 1)) * vValue;
                } else {
                    piValue = vValue;
                }
                this.pi.set(u, k, piValue);
                this.logpi.set(u, k, Math.log(piValue));
                this.userFactors.set(u, k, piValue * scalar);
            }
        }
    }

    /**
     * Walks the recurrence from index {@code K - 1} up to the truncation.
     *
     * @return {@code {p, sum}}: {@code sum} is the sum over k in [K, numFactors) of pi[u][k] * eBetaSum[k]
     *     (pi recomputed via the recurrence); {@code p} is the last pi visited
     */
    private double[] sum_of_prod_in_range(int u, int K) {
        double sum = 0.0;
        double p = this.convert_oldpi_to_new(u, K - 1);
        for (int k = K; k < this.numFactors; ++k) {
            p = this.convert_oldpi_to_new(p, u, k);
            sum += p * this.eBetaSum.get(k);
        }
        return new double[]{p, sum};
    }

    /** pi[u][k] recomputed from stored pi[u][k-1], v[u][k-1] and v[u][k]; k == 0 returns pi[u][0]. */
    private double convert_oldpi_to_new(int u, int k) {
        if (k == 0) {
            return this.pi.get(u, 0);
        }
        return this.pi.get(u, k - 1) * ((1.0 - this.v.get(u, k - 1)) / this.v.get(u, k - 1)) * this.v.get(u, k);
    }

    /** pi[u][k] recomputed from a supplied pi[u][k-1] and the stored v[u][k-1], v[u][k]. */
    private double convert_oldpi_to_new(double pi_at_kminus1, int u, int k) {
        return pi_at_kminus1 * ((1.0 - this.v.get(u, k - 1)) / this.v.get(u, k - 1)) * this.v.get(u, k);
    }

    /** Expected rate contribution of the factors beyond the truncation: remaining stick * d_scalar. */
    private double compute_scalar_rate_infsum(int u) {
        double y = this.computeY(u);
        return y * this.d_scalar;
    }

    /** Stick mass remaining after the last break: pi[u][T-1] / v[u][T-1] * (1 - v[u][T-1]). */
    private double computeY(int u) {
        double vValue = this.v.get(u, this.numFactors - 1);
        return this.pi.get(u, this.numFactors - 1) / vValue * (1.0 - vValue);
    }

    /** Recomputed pi[u][k] times eBetaSum[k]. */
    private double prob_at_k(int u, int k) {
        return this.convert_oldpi_to_new(u, k) * this.eBetaSum.get(k);
    }

    /**
     * Linear coefficient of the stick-update quadratic for (u, k).
     *
     * <p>NOTE(review): z[u][k] appears here explicitly. It is also included in
     * {@code compute_zUser_sum(u, k)}, which subtracts only z[u][j] for j &lt; k from the budget.
     * The usual stationarity condition counts z[u][k] once plus the mass for j &gt; k. Confirm
     * whether {@code compute_zUser_sum(u, k + 1)} was intended.
     */
    private double B(int u, int k, double Auk) {
        return this.alpha - 1.0 + this.zUsers.get(u, k) - Auk + this.compute_zUser_sum(u, k);
    }

    /** userBudget[u] minus the sum of zUsers[u][k] for k &lt; tok. */
    private double compute_zUser_sum(int u, int tok) {
        double sum = 0.0;
        for (int k = 0; k < tok; ++k) {
            sum += this.zUsers.get(u, k);
        }
        return this.userBudget.get(u) - sum;
    }

    /**
     * Returns a root of {@code a x^2 + b x + c = 0}, preferring one in (0, 1]:
     * <ul>
     *   <li>If both roots are in (0, 1], returns the smaller one plus 1e-30.</li>
     *   <li>If only one root is in (0, 1], returns that root.</li>
     *   <li>A root within 1e-30 of 0 maps to 1e-30; a root within 1e-30 of 1 maps to 1.0.</li>
     *   <li>Otherwise returns the "+sqrt" root, which may lie outside (0, 1] or be NaN when the
     *       discriminant is negative.</li>
     * </ul>
     */
    private double solve_quadratic(double a, double b, double c) {
        double s1 = (-b + Math.sqrt(b * b - 4.0 * a * c)) / (2.0 * a);
        double s2 = (-b - Math.sqrt(b * b - 4.0 * a * c)) / (2.0 * a);
        if (s1 > 0.0 && s1 <= 1.0 && s2 > 0.0 && s2 <= 1.0) {
            if (s1 < s2) {
                return s1 + 1.0E-30;
            }
            return s2 + 1.0E-30;
        }
        if (s1 > 0.0 && s1 <= 1.0) {
            return s1;
        }
        if (s2 > 0.0 && s2 <= 1.0) {
            return s2;
        }
        if (Math.abs(s1 - 0.0) < 1.0E-30) {
            return 1.0E-30;
        }
        if (Math.abs(s1 - 1.0) < 1.0E-30) {
            return 1.0;
        }
        if (Math.abs(s2 - 0.0) < 1.0E-30) {
            return 1.0E-30;
        }
        if (Math.abs(s2 - 1.0) < 1.0E-30) {
            return 1.0;
        }
        return s1;
    }

    /** Updates each user's scalar {@code s}, then refreshes its expectations and {@code eThetaSum}. */
    private void update_sticks_scalar() {
        for (int u = 0; u < this.numUsers; ++u) {
            double infsum = this.compute_scalar_rate_infsum(u);
            double fnsum = this.compute_scalar_rate_finitesum(u);
            // NOTE(review): the shape/rate are set without the Gamma prior (alpha, c). That prior is
            // held in s.shapePrior / s.ratePrior but is only used at initialisation. Confirm intended.
            this.s.shape.set(u, this.userBudget.get(u));
            this.s.rate.set(u, fnsum + infsum);
        }
        // Neither rate term reads s or eThetaSum, so refreshing once after the loop is exact. The
        // previous per-user refresh made this step O(numUsers^2 * numFactors).
        this.s.computeExpectations();
        this.computeEThetaSum();
    }

    /**
     * Expected rate of user u over the truncated factors: the sum over k of pi[u][k] * eBetaSum[k].
     * This matches the product form used by the tail term and prob_at_k; the previous code added
     * the two sums instead of multiplying them.
     */
    private double compute_scalar_rate_finitesum(int u) {
        double sum = 0.0;
        for (int k = 0; k < this.numFactors; ++k) {
            sum += this.pi.get(u, k) * this.eBetaSum.get(k);
        }
        return sum;
    }

    /** Sets beta's shapes to zItems and its per-factor rate to eThetaSum, then refreshes caches. */
    private void updateItems() {
        // NOTE(review): as for s, the Gamma prior (a, b) is not added to the shape/rate here.
        // Confirm intended.
        for (int i = 0; i < this.numItems; ++i) {
            this.beta.shape.set(i, this.zItems.row(i));
        }
        this.beta.rate = new VectorBasedDenseVector(this.eThetaSum);
        this.beta.computeExpectations();
        this.computeEBetaSum();
    }

    /**
     * Gamma variational factors with a per-cell shape and a per-column rate shared by all rows.
     * {@code value} holds shape/rate and {@code logValue} holds digamma(shape) - log(rate).
     */
    public class GammaDenseMatrixGR {
        protected int numRows;
        protected int numColumns;
        protected double shapePrior;
        protected double ratePrior;
        protected DenseMatrix shape;
        protected DenseVector rate;
        protected DenseMatrix value;
        protected DenseMatrix logValue;

        public GammaDenseMatrixGR(int _numRows, int _numColumns, double shapeP, double rateP) {
            this.shapePrior = shapeP;
            this.ratePrior = rateP;
            this.numRows = _numRows;
            this.numColumns = _numColumns;
            this.shape = new DenseMatrix(this.numRows, this.numColumns);
            this.rate = new VectorBasedDenseVector(this.numColumns);
            this.value = new DenseMatrix(this.numRows, this.numColumns);
            this.logValue = new DenseMatrix(this.numRows, this.numColumns);
        }

        /**
         * Initialises the parameters around the priors and computes the expectations.
         * Shapes become prior + 0.01 * U(0,1); rates become prior + 0.1 * U(0,1).
         */
        public void init() {
            for (int i = 0; i < this.numRows; ++i) {
                for (int j = 0; j < this.numColumns; ++j) {
                    this.shape.set(i, j, this.shapePrior + 0.01 * Randoms.uniform(0.0, 1.0));
                }
            }
            for (int j = 0; j < this.numColumns; ++j) {
                this.rate.set(j, this.ratePrior + 0.1 * Randoms.uniform(0.0, 1.0));
            }
            for (int i = 0; i < this.numRows; ++i) {
                for (int j = 0; j < this.numColumns; ++j) {
                    this.value.set(i, j, this.shape.get(i, j) / this.rate.get(j));
                    this.logValue.set(i, j, Gamma.digamma(this.shape.get(i, j)) - Math.log(this.rate.get(j)));
                }
            }
        }

        /** Recomputes value/logValue; non-positive shapes or rates are replaced by 1e-30. */
        public void computeExpectations() {
            for (int i = 0; i < this.numRows; ++i) {
                for (int j = 0; j < this.numColumns; ++j) {
                    double av = this.shape.get(i, j);
                    double bv = this.rate.get(j);
                    double shapeValue = av <= 0.0 ? 1.0E-30 : av;
                    double rateValue = bv <= 0.0 ? 1.0E-30 : bv;
                    this.value.set(i, j, shapeValue / rateValue);
                    this.logValue.set(i, j, Gamma.digamma(shapeValue) - Math.log(rateValue));
                }
            }
        }
    }

    /**
     * Gamma variational factors with a per-cell shape and a per-cell rate.
     * At initialisation, every row's rates are copied from row 0. Not used by this recommender.
     */
    public class GammaDenseMatrix {
        protected int numRows;
        protected int numColumns;
        protected double shapePrior;
        protected double ratePrior;
        protected DenseMatrix shape;
        protected DenseMatrix rate;
        protected DenseMatrix value;
        protected DenseMatrix logValue;

        public GammaDenseMatrix(int _numRows, int _numColumns, double shapeP, double rateP) {
            this.shapePrior = shapeP;
            this.ratePrior = rateP;
            this.numRows = _numRows;
            this.numColumns = _numColumns;
            this.shape = new DenseMatrix(this.numRows, this.numColumns);
            this.rate = new DenseMatrix(this.numRows, this.numColumns);
            this.value = new DenseMatrix(this.numRows, this.numColumns);
            this.logValue = new DenseMatrix(this.numRows, this.numColumns);
        }

        /** Initialises the shapes and rates around the priors and computes the expectations. */
        public void init() {
            for (int i = 0; i < this.numRows; ++i) {
                for (int j = 0; j < this.numColumns; ++j) {
                    this.shape.set(i, j, this.shapePrior + 0.01 * Randoms.uniform(0.0, 1.0));
                    if (i == 0) {
                        this.rate.set(0, j, this.ratePrior + 0.1 * Randoms.uniform(0.0, 1.0));
                        continue;
                    }
                    this.rate.set(i, j, this.rate.get(0, j));
                }
            }
            for (int i = 0; i < this.numRows; ++i) {
                for (int j = 0; j < this.numColumns; ++j) {
                    this.value.set(i, j, this.shape.get(i, j) / this.rate.get(i, j));
                    this.logValue.set(i, j, Gamma.digamma(this.shape.get(i, j)) - Math.log(this.rate.get(i, j)));
                }
            }
        }

        /** Recomputes value/logValue; non-positive shapes or rates are replaced by 1e-30. */
        public void computeExpectations() {
            for (int i = 0; i < this.numRows; ++i) {
                for (int j = 0; j < this.numColumns; ++j) {
                    double av = this.shape.get(i, j);
                    double bv = this.rate.get(i, j);
                    double shapeValue = av <= 0.0 ? 1.0E-30 : av;
                    double rateValue = bv <= 0.0 ? 1.0E-30 : bv;
                    this.value.set(i, j, shapeValue / rateValue);
                    this.logValue.set(i, j, Gamma.digamma(shapeValue) - Math.log(rateValue));
                }
            }
        }
    }

    /** Vector of independent Gamma variational factors, each with its own shape and rate. */
    public class GammaDenseVector {
        protected int size;
        protected double shapePrior;
        protected double ratePrior;
        protected DenseVector shape;
        protected DenseVector rate;
        protected DenseVector value;
        protected DenseVector logValue;

        public GammaDenseVector(int _size, double shapeP, double rateP) {
            this.shapePrior = shapeP;
            this.ratePrior = rateP;
            this.size = _size;
            this.shape = new VectorBasedDenseVector(this.size);
            this.rate = new VectorBasedDenseVector(this.size);
            this.value = new VectorBasedDenseVector(this.size);
            this.logValue = new VectorBasedDenseVector(this.size);
        }

        /** Initialises the shapes and rates around the priors and computes the expectations. */
        public void init() {
            for (int i = 0; i < this.size; ++i) {
                this.shape.set(i, this.shapePrior + 0.01 * Randoms.uniform(0.0, 1.0));
                this.rate.set(i, this.ratePrior + 0.1 * Randoms.uniform(0.0, 1.0));
            }
            for (int i = 0; i < this.size; ++i) {
                this.value.set(i, this.shape.get(i) / this.rate.get(i));
                this.logValue.set(i, Gamma.digamma(this.shape.get(i)) - Math.log(this.rate.get(i)));
            }
        }

        /**
         * Sets shapes to prior + 0.01 * U(0,1) and rates to ratePrior + v.
         * Does not recompute value/logValue.
         */
        public void init2(double v) {
            for (int i = 0; i < this.size; ++i) {
                this.shape.set(i, this.shapePrior + 0.01 * Randoms.uniform(0.0, 1.0));
                this.rate.set(i, this.ratePrior + v);
            }
        }

        /**
         * Replaces value/logValue with freshly allocated vectors computed from shape/rate.
         * Non-positive shapes or rates are replaced by 1e-30.
         */
        public void computeExpectations() {
            this.value = new VectorBasedDenseVector(this.size);
            this.logValue = new VectorBasedDenseVector(this.size);
            for (int i = 0; i < this.size; ++i) {
                double av = this.shape.get(i);
                double bv = this.rate.get(i);
                double shapeValue = av <= 0.0 ? 1.0E-30 : av;
                double rateValue = bv <= 0.0 ? 1.0E-30 : bv;
                this.value.set(i, shapeValue / rateValue);
                this.logValue.set(i, Gamma.digamma(shapeValue) - Math.log(rateValue));
            }
        }
    }
}

