/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.rating;

import java.util.ArrayList;
import java.util.List;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.RowSequentialAccessSparseMatrix;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Bayesian Probabilistic Matrix Factorization (BPMF) rating predictor.
 *
 * <p>User and item factor matrices are redrawn by Gibbs sampling together with their
 * Gaussian-Wishart hyper-parameters. The prediction for each test cell is the running average of
 * {@code globalMean + userFactor . itemFactor} over the retained samples.
 *
 * <p>Reference: R. Salakhutdinov and A. Mnih, "Bayesian Probabilistic Matrix Factorization using
 * Markov Chain Monte Carlo", ICML 2008.
 */
public class BPMFRecommender extends MatrixFactorizationRecommender {
    /** Value assigned to every component of the user prior mean ({@code rec.recommender.user.mu}). */
    private double userMu0;
    /** Prior strength of the user hyper-mean ({@code rec.recommender.user.beta}). */
    private double userBeta0;
    /** Diagonal entry of the user Wishart scale ({@code rec.recommender.user.wishart.scale}). */
    private double userWishartScale0;
    /** Value assigned to every component of the item prior mean ({@code rec.recommender.item.mu}). */
    private double itemMu0;
    /** Prior strength of the item hyper-mean ({@code rec.recommender.item.beta}). */
    private double itemBeta0;
    /** Diagonal entry of the item Wishart scale ({@code rec.recommender.item.wishart.scale}). */
    private double itemWishartScale0;
    /** Prior mean vector for the user hyper-mean. */
    private DenseVector userMu;
    /** Prior mean vector for the item hyper-mean. */
    private DenseVector itemMu;
    /** Inverted user Wishart scale matrix, passed to {@link #samplingHyperParameters}. */
    private DenseMatrix userWishartScale;
    /** Inverted item Wishart scale matrix, passed to {@link #samplingHyperParameters}. */
    private DenseMatrix itemWishartScale;
    /** Prior strength used when sampling the user hyper-parameters. */
    private double userBeta;
    /** Prior strength used when sampling the item hyper-parameters. */
    private double itemBeta;
    /** Prior Wishart degrees of freedom for users; set to {@code numFactors} in {@link #initModel}. */
    private double userWishartNu;
    /** Prior Wishart degrees of freedom for items; set to {@code numFactors} in {@link #initModel}. */
    private double itemWishartNu;
    /**
     * Rating observation weight ({@code rec.recommender.rating.sigma}). Despite its name, it
     * multiplies {@code X^T X} and {@code X^T r} in {@link #updateParameters}, i.e. it acts as a
     * precision rather than a standard deviation.
     */
    private double ratingSigma;
    /** Gibbs sweeps over users and items per outer iteration ({@code rec.recommender.gibbs.iterations}). */
    private int gibbsIterations;
    /** Running-average predictions for the cells of the test matrix. */
    private RowSequentialAccessSparseMatrix predictMatrix;

    /**
     * Reads the BPMF hyper-parameters from the configuration, falling back to the defaults below.
     *
     * @throws LibrecException if the base-class setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.userMu0 = this.conf.getDouble("rec.recommender.user.mu", 0.0);
        this.userBeta0 = this.conf.getDouble("rec.recommender.user.beta", 1.0);
        this.userWishartScale0 = this.conf.getDouble("rec.recommender.user.wishart.scale", 1.0);
        this.itemMu0 = this.conf.getDouble("rec.recommender.item.mu", 0.0);
        this.itemBeta0 = this.conf.getDouble("rec.recommender.item.beta", 1.0);
        this.itemWishartScale0 = this.conf.getDouble("rec.recommender.item.wishart.scale", 1.0);
        this.ratingSigma = this.conf.getDouble("rec.recommender.rating.sigma", 2.0);
        this.gibbsIterations = this.conf.getInt("rec.recommender.gibbs.iterations", 1);
    }

    /**
     * Initialises the prior parameters and an all-zero prediction store over the test cells.
     *
     * <p>The Wishart scale matrices are stored inverted because {@link #samplingHyperParameters}
     * adds the data scatter to them directly before inverting the sum.
     *
     * @throws LibrecException if a matrix operation fails
     */
    protected void initModel() throws LibrecException {
        this.userMu = new VectorBasedDenseVector(this.numFactors);
        this.userMu.assign((index, value) -> this.userMu0);
        this.itemMu = new VectorBasedDenseVector(this.numFactors);
        this.itemMu.assign((index, value) -> this.itemMu0);
        this.userBeta = this.userBeta0;
        this.itemBeta = this.itemBeta0;
        DenseMatrix userScale = new DenseMatrix(this.numFactors, this.numFactors);
        DenseMatrix itemScale = new DenseMatrix(this.numFactors, this.numFactors);
        for (int i = 0; i < this.numFactors; ++i) {
            userScale.set(i, i, this.userWishartScale0);
            itemScale.set(i, i, this.itemWishartScale0);
        }
        // Keep the matrix returned by inverse(); the previous code discarded it, so the configured
        // scale was used un-inverted.
        this.userWishartScale = userScale.inverse();
        this.itemWishartScale = itemScale.inverse();
        this.userWishartNu = this.numFactors;
        this.itemWishartNu = this.numFactors;
        // Copy the test matrix for its sparsity pattern only, then zero every cell so the
        // ground-truth test ratings can never be returned by predict().
        this.predictMatrix = new RowSequentialAccessSparseMatrix(this.testMatrix);
        for (MatrixEntry entry : this.testMatrix) {
            this.predictMatrix.set(entry.row(), entry.column(), 0.0);
        }
    }

    /**
     * Runs the Gibbs sampler and accumulates averaged predictions for every test cell.
     *
     * <p>Each outer iteration resamples both sides' hyper-parameters, performs
     * {@code gibbsIterations} sweeps over users then items, and folds the resulting sample into
     * {@link #predictMatrix}. The first iteration is discarded as burn-in unless it is the only one.
     *
     * @throws LibrecException if a matrix operation fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        this.initModel();
        List<SequentialSparseVector> userTrainVectors = new ArrayList<>(this.numUsers);
        for (int u = 0; u < this.numUsers; ++u) {
            userTrainVectors.add(this.trainMatrix.row(u));
        }
        List<SequentialSparseVector> itemTrainVectors = new ArrayList<>(this.numItems);
        for (int i = 0; i < this.numItems; ++i) {
            itemTrainVectors.add(this.trainMatrix.column(i));
        }
        // Seed the hyper-parameters from the current factors: per-column means and the inverse of
        // the empirical covariance (HyperParameters.variance is consumed as a precision matrix).
        HyperParameters userHyperParameters = new HyperParameters(
                this.columnMeans(this.userFactors), this.userFactors.covariance().inverse());
        HyperParameters itemHyperParameters = new HyperParameters(
                this.columnMeans(this.itemFactors), this.itemFactors.covariance().inverse());
        // With a single iteration there is nothing after burn-in, so use that lone sample rather
        // than leaving the predictions empty.
        int burnIn = this.numIterations > 1 ? 1 : 0;
        for (int iter = 0; iter < this.numIterations; ++iter) {
            userHyperParameters = this.samplingHyperParameters(userHyperParameters, this.userFactors,
                    this.userMu, this.userBeta, this.userWishartScale, this.userWishartNu);
            itemHyperParameters = this.samplingHyperParameters(itemHyperParameters, this.itemFactors,
                    this.itemMu, this.itemBeta, this.itemWishartScale, this.itemWishartNu);
            for (int gibbsIteration = 0; gibbsIteration < this.gibbsIterations; ++gibbsIteration) {
                this.sampleFactors(this.userFactors, userTrainVectors, this.itemFactors, userHyperParameters);
                this.sampleFactors(this.itemFactors, itemTrainVectors, this.userFactors, itemHyperParameters);
            }
            if (iter >= burnIn) {
                this.accumulatePredictions(iter - burnIn + 1);
            }
        }
    }

    /**
     * Computes the mean of every column of {@code factors}.
     *
     * @param factors matrix whose columns are averaged
     * @return vector of length {@code factors.columnSize()} holding the column means
     * @throws LibrecException if a matrix operation fails
     */
    private VectorBasedDenseVector columnMeans(DenseMatrix factors) throws LibrecException {
        int numColumns = factors.columnSize();
        VectorBasedDenseVector means = new VectorBasedDenseVector(numColumns);
        for (int c = 0; c < numColumns; ++c) {
            means.set(c, factors.column(c).mean());
        }
        return means;
    }

    /**
     * Redraws, in place, every row of {@code target} whose rating vector is non-empty. Rows with
     * no training ratings keep their current values.
     *
     * @param target          factor matrix being resampled (one row per rating vector)
     * @param ratingVectors   training ratings, indexed like the rows of {@code target}
     * @param fixedFactors    factors of the other side, held fixed during this sweep
     * @param hyperParameters hyper-parameters of {@code target}'s side
     * @throws LibrecException if a matrix operation fails
     */
    private void sampleFactors(DenseMatrix target, List<SequentialSparseVector> ratingVectors,
            DenseMatrix fixedFactors, HyperParameters hyperParameters) throws LibrecException {
        for (int rowIdx = 0; rowIdx < ratingVectors.size(); ++rowIdx) {
            SequentialSparseVector ratings = ratingVectors.get(rowIdx);
            if (ratings.getNumEntries() == 0) {
                continue;
            }
            // Declared per row so it is effectively final and may be captured by the lambda.
            DenseVector sampled = this.updateParameters(fixedFactors, ratings, hyperParameters);
            target.row(rowIdx).assign((index, value) -> sampled.get(index));
        }
    }

    /**
     * Folds the current factor sample into the running average stored in {@link #predictMatrix}.
     *
     * @param numSamples number of samples averaged so far, including the current one (at least 1)
     * @throws LibrecException if a matrix operation fails
     */
    private void accumulatePredictions(int numSamples) throws LibrecException {
        for (MatrixEntry entry : this.testMatrix) {
            int userIdx = entry.row();
            int itemIdx = entry.column();
            double sample = this.globalMean
                    + this.userFactors.row(userIdx).dot(this.itemFactors.row(itemIdx));
            double previous = this.predictMatrix.get(userIdx, itemIdx);
            this.predictMatrix.set(userIdx, itemIdx, (previous * (numSamples - 1) + sample) / numSamples);
        }
    }

    /**
     * Draws new Gaussian-Wishart hyper-parameters for one side (users or items) given its factors.
     *
     * <p>{@code hyperParameters} is updated in place and returned. Its {@code variance} is replaced
     * only when the Wishart draw succeeds, and its {@code mu} only when the Cholesky decomposition
     * succeeds; otherwise the previous values are kept.
     *
     * @param hyperParameters current hyper-parameters, updated in place
     * @param factors         factor matrix, one row per user/item
     * @param normalMu0       prior mean of the hyper-mean
     * @param normalBeta0     prior strength of the hyper-mean
     * @param wishartScale0   inverted prior Wishart scale matrix
     * @param wishartNu0      prior Wishart degrees of freedom
     * @return {@code hyperParameters}
     * @throws LibrecException if a matrix operation or the Wishart draw fails
     */
    protected HyperParameters samplingHyperParameters(HyperParameters hyperParameters, DenseMatrix factors,
            DenseVector normalMu0, double normalBeta0, DenseMatrix wishartScale0, double wishartNu0)
            throws LibrecException {
        int numRows = factors.rowSize();
        int numColumns = factors.columnSize();
        VectorBasedDenseVector mean = this.columnMeans(factors);
        DenseMatrix populationVariance = factors.covariance();
        double betaPost = normalBeta0 + numRows;
        DenseVector muPost = normalMu0.times(normalBeta0).plus(mean.times(numRows)).times(1.0 / betaPost);
        DenseMatrix wishartScalePost = wishartScale0.plus(populationVariance.times(numRows));
        DenseVector muError = normalMu0.minus(mean);
        wishartScalePost = wishartScalePost
                .plus(muError.outer(muError).times(normalBeta0 * numRows / betaPost))
                .inverse();
        // Symmetrise to remove round-off asymmetry introduced by the inversion.
        wishartScalePost = wishartScalePost.plus(wishartScalePost.transpose()).times(0.5);
        // Posterior degrees of freedom are the prior's plus the number of observations; the prior
        // value used to be ignored in favour of numColumns.
        DenseMatrix precision = Randoms.wishart(wishartScalePost, wishartNu0 + numRows);
        if (precision != null) {
            hyperParameters.variance = precision;
        }
        // The hyper-mean is drawn with the posterior strength betaPost (not the prior normalBeta0).
        DenseMatrix cholFactor = hyperParameters.variance.times(betaPost).inverse().cholesky();
        if (cholFactor != null) {
            VectorBasedDenseVector noise = new VectorBasedDenseVector(numColumns);
            for (int f = 0; f < numColumns; ++f) {
                noise.set(f, Randoms.gaussian(0.0, 1.0));
            }
            hyperParameters.mu = cholFactor.transpose().times(noise).plus(muPost);
        }
        return hyperParameters;
    }

    /**
     * Samples one factor vector from its conditional posterior given the other side's factors.
     *
     * @param factors         factors of the other side (rows indexed by the rating indices)
     * @param ratings         observed ratings of the row being resampled
     * @param hyperParameters hyper-parameters of the side being resampled
     * @return a sampled factor vector of length {@code numFactors}; the posterior mean if the
     *         Cholesky decomposition of the posterior covariance fails
     * @throws LibrecException if a matrix operation fails
     */
    protected DenseVector updateParameters(DenseMatrix factors, SequentialSparseVector ratings,
            HyperParameters hyperParameters) throws LibrecException {
        int numRated = ratings.getNumEntries();
        DenseMatrix observed = new DenseMatrix(numRated, this.numFactors);
        VectorBasedDenseVector residuals = new VectorBasedDenseVector(numRated);
        int position = 0;
        for (int j : ratings.getIndices()) {
            residuals.set(position, ratings.get(j) - this.globalMean);
            observed.row(position).assign((col, value) -> factors.row(j).get(col));
            ++position;
        }
        DenseMatrix observedT = observed.transpose();
        DenseMatrix posteriorCovariance = hyperParameters.variance
                .plus(observedT.times(observed).times(this.ratingSigma))
                .inverse();
        DenseVector posteriorMean = posteriorCovariance.times(observedT.times(residuals).times(this.ratingSigma)
                .plus(hyperParameters.variance.times(hyperParameters.mu)));
        DenseMatrix cholFactor = posteriorCovariance.cholesky();
        if (cholFactor == null) {
            // Cholesky failed: fall back to the posterior mean instead of overwriting the factor row
            // with an all-zero vector.
            return posteriorMean;
        }
        VectorBasedDenseVector noise = new VectorBasedDenseVector(this.numFactors);
        for (int f = 0; f < this.numFactors; ++f) {
            noise.set(f, Randoms.gaussian(0.0, 1.0));
        }
        return cholFactor.transpose().times(noise).plus(posteriorMean);
    }

    /**
     * Returns the averaged sampled prediction for a test cell. Only cells of the test matrix are
     * populated by {@link #trainModel}.
     *
     * @param userIdx user index
     * @param itemIdx item index
     * @return the stored prediction for {@code (userIdx, itemIdx)}
     */
    @Override
    protected double predict(int userIdx, int itemIdx) {
        return this.predictMatrix.get(userIdx, itemIdx);
    }

    /** Mutable holder for one side's (users' or items') Gaussian-Wishart hyper-parameters. */
    public class HyperParameters {
        /** Hyper-mean of the factor vectors. */
        public DenseVector mu;
        /**
         * Despite its name, this is used as a precision matrix: the enclosing recommender's
         * {@code updateParameters} adds it to {@code ratingSigma * X^T X} before inverting.
         */
        public DenseMatrix variance;

        HyperParameters(DenseVector mu, DenseMatrix variance) {
            this.mu = mu;
            this.variance = variance;
        }
    }
}

