/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.content;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.Table;
import java.util.ArrayList;
import java.util.Iterator;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.structure.DataFrame;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SequentialAccessSparseMatrix;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.SparseTensor;
import net.librec.math.structure.TensorEntry;
import net.librec.math.structure.Vector;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.TensorRecommender;

/**
 * Rating recommender that couples biased matrix factorization with a topic model learned from
 * review text (TopicMF-style).
 *
 * <p>Every entry of the training tensor is a (user, item, review) triple whose review content is a
 * {@code ':'}-separated list of word tokens. Each tensor entry becomes one "document", i.e. one row
 * of {@link #W}. The topic distribution of the document that belongs to (u, i) is tied to the latent
 * factors through {@code theta[d] = softmax(K1 * |U[u]| + K2 * |V[i]|)}. {@link #W} is approximated
 * by {@code theta * phi}.
 *
 * <p>Ratings are predicted as {@code U[u] . V[i] + b_u + b_i + globalMean}.
 *
 * <p>NOTE(review): "AT" presumably refers to the additive transform {@code K1*|u| + K2*|v|} used to
 * build theta. Confirm against the TopicMF paper.
 */
public class TopicMFATRecommender extends TensorRecommender {
    /** Rating matrix derived from the training tensor. */
    protected SequentialAccessSparseMatrix trainMatrix;
    /** Document-word matrix: W(d, w) is the fraction of the tokens of document d that equal word w. */
    protected SequentialAccessSparseMatrix W;
    /** Document-topic matrix (numDocuments x numTopics); each computed row is a softmax output. */
    protected DenseMatrix theta;
    /** Topic-word matrix (numTopics x numWords). */
    protected DenseMatrix phi;
    /** Scale applied to |user factors| inside the softmax that produces theta. */
    protected double K1;
    /** Scale applied to |item factors| inside the softmax that produces theta. */
    protected double K2;
    /** Per-user rating bias. */
    protected VectorBasedDenseVector userBiases;
    /** Per-item rating bias. */
    protected VectorBasedDenseVector itemBiases;
    /** User latent factors (numUsers x numTopics). */
    protected DenseMatrix userFactors;
    /** Item latent factors (numItems x numTopics). */
    protected DenseMatrix itemFactors;
    /** Number of topics, which is also the number of latent factors. */
    protected int numTopics;
    /** Size of the word vocabulary built in {@link #setup()}. */
    protected int numWords;
    /** Number of documents (one per training tensor entry), i.e. rows of W and theta. */
    protected int numDocuments;
    /** Inner review index to raw review content. */
    protected BiMap<Integer, String> reviewMappingData;
    /** Weight of the word-reconstruction error in the K1/K2 updates. */
    protected double lambda;
    /** L2 regularization of the user factors. */
    protected double lambdaU;
    /** L2 regularization of the item factors. */
    protected double lambdaV;
    /** L2 regularization of the user and item biases. */
    protected double lambdaB;
    /** Word token to column index in W and phi. */
    protected BiMap<String, Integer> wordIdToWordIndex;
    /** (user, item) to document row. If a pair occurs more than once, the last document wins. */
    protected Table<Integer, Integer, Integer> userItemToDocument;
    /** Mean used to initialize biases and factors. */
    protected float initMean;
    /** Standard deviation used to initialize biases and factors; also the initial K1 and K2. */
    protected float initStd;
    /** Not referenced in this class; presumably used by subclasses or reporting. TODO confirm. */
    protected int[][] documentTopWordIdices;
    /** Not referenced in this class; presumably used by subclasses or reporting. TODO confirm. */
    protected int topNum = 5;

    /**
     * Reads hyper-parameters, builds the document-word matrix from the review texts and initializes
     * all model parameters.
     *
     * <p>Configuration keys read here: {@code rec.regularization.lambda}, {@code .lambdaU},
     * {@code .lambdaV}, {@code .lambdaB}, {@code rec.topic.number}, {@code rec.iterator.learnrate},
     * {@code rec.iterator.maximum}, {@code rec.init.mean} and {@code rec.init.std}.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.reviewMappingData = DataFrame.getInnerMapping("review").inverse();
        this.lambda = this.conf.getDouble("rec.regularization.lambda", 0.001);
        this.lambdaU = this.conf.getDouble("rec.regularization.lambdaU", 0.001);
        this.lambdaV = this.conf.getDouble("rec.regularization.lambdaV", 0.001);
        this.lambdaB = this.conf.getDouble("rec.regularization.lambdaB", 0.001);
        this.numTopics = this.conf.getInt("rec.topic.number", 10);
        this.learnRate = this.conf.getFloat("rec.iterator.learnrate", Float.valueOf(0.01f)).floatValue();
        this.numIterations = this.conf.getInt("rec.iterator.maximum", 10);
        this.trainTensor = (SparseTensor) this.getDataModel().getTrainDataSet();
        this.trainMatrix = this.trainTensor.rateMatrix();

        // Build the vocabulary and the (unnormalized) rows of W, one document per tensor entry.
        this.wordIdToWordIndex = HashBiMap.create();
        this.numWords = 0; // the vocabulary map is rebuilt here, so its size counter must restart too
        this.userItemToDocument = HashBasedTable.create();
        HashBasedTable<Integer, Integer, Double> res = HashBasedTable.create();
        int rowCount = 0;
        for (TensorEntry te : this.trainTensor) {
            int[] entryKeys = te.keys();
            int userIndex = entryKeys[0];
            int itemIndex = entryKeys[1];
            int reviewIndex = entryKeys[2];
            this.userItemToDocument.put(userIndex, itemIndex, rowCount);
            String[] words = this.reviewMappingData.get(reviewIndex).split(":");
            double denominator = words.length;
            for (String word : words) {
                Integer wordIdx = this.wordIdToWordIndex.get(word);
                if (wordIdx == null) {
                    wordIdx = this.numWords++;
                    this.wordIdToWordIndex.put(word, wordIdx);
                }
                // Each occurrence contributes 1/|tokens|, so W(d, w) is the word's relative frequency.
                Double oldValue = res.get(rowCount, wordIdx);
                double newValue = (oldValue == null ? 0.0 : oldValue) + 1.0 / denominator;
                res.put(rowCount, wordIdx, newValue);
            }
            ++rowCount;
        }
        // Document ids run from 0 to rowCount - 1, so rowCount is the number of documents.
        // NOTE(review): the previous trainMatrix.size() can presumably be smaller when the tensor holds
        // several entries for the same (user, item) pair. That would leave document ids out of range.
        this.numDocuments = rowCount;
        this.W = new SequentialAccessSparseMatrix(this.numDocuments, this.numWords, res);

        this.initMean = this.conf.getFloat("rec.init.mean", Float.valueOf(0.0f)).floatValue();
        this.initStd = this.conf.getFloat("rec.init.std", Float.valueOf(0.01f)).floatValue();
        this.userBiases = new VectorBasedDenseVector(this.numUsers);
        this.userBiases.init(this.initMean, this.initStd);
        this.itemBiases = new VectorBasedDenseVector(this.numItems);
        this.itemBiases.init(this.initMean, this.initStd);
        this.userFactors = new DenseMatrix(this.numUsers, this.numTopics);
        this.userFactors.init(this.initMean, this.initStd);
        this.itemFactors = new DenseMatrix(this.numItems, this.numTopics);
        this.itemFactors.init(this.initMean, this.initStd);
        this.K1 = this.initStd;
        this.K2 = this.initStd;
        this.theta = new DenseMatrix(this.numDocuments, this.numTopics);
        this.calculateTheta();
        this.phi = new DenseMatrix(this.numTopics, this.numWords);
        this.phi.init(0.01);
        this.LOG.info("number of users : " + this.numUsers);
        this.LOG.info("number of Items : " + this.numItems);
        this.LOG.info("number of words : " + this.wordIdToWordIndex.size());
    }

    /**
     * Runs {@code numIterations} training epochs. Each epoch:
     * <ol>
     *   <li>performs SGD on biases and latent factors using the rating error, adds a term driven by
     *       the word-reconstruction error of the entry's document, and adapts K1/K2;</li>
     *   <li>recomputes theta from the updated factors;</li>
     *   <li>updates phi multiplicatively.</li>
     * </ol>
     * The logged loss is half of: the squared rating error, plus the regularization terms, plus the
     * squared word-reconstruction error.
     *
     * @throws LibrecException declared for the recommender contract
     */
    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            // Squared reconstruction error of W ~ theta * phi. Each document is visited once per
            // latent factor below, so the sum is divided by numTopics at the end of the epoch.
            double wordLoss = 0.0;
            for (MatrixEntry me : this.trainMatrix) {
                int userIdx = me.row();
                int itemIdx = me.column();
                int documentId = this.userItemToDocument.get(userIdx, itemIdx);
                double yTrue = me.get();
                double yPred = this.predict(userIdx, itemIdx);
                double error = yTrue - yPred;
                this.loss += error * error;

                double userBiasValue = this.userBiases.get(userIdx);
                this.userBiases.plus(userIdx,
                        (double) this.learnRate * (error - this.lambdaB * userBiasValue));
                this.loss += this.lambdaB * userBiasValue * userBiasValue;
                double itemBiasValue = this.itemBiases.get(itemIdx);
                this.itemBiases.plus(itemIdx,
                        (double) this.learnRate * (error - this.lambdaB * itemBiasValue));
                this.loss += this.lambdaB * itemBiasValue * itemBiasValue;

                // W is never modified during training, so the document's word row is loop-invariant.
                SequentialSparseVector wordVec = this.W.row(documentId);
                for (int factorIdx = 0; factorIdx < this.numTopics; ++factorIdx) {
                    double userFactorValue = this.userFactors.get(userIdx, factorIdx);
                    double itemFactorValue = this.itemFactors.get(itemIdx, factorIdx);
                    this.userFactors.plus(userIdx, factorIdx, (double) this.learnRate
                            * (error * itemFactorValue - this.lambdaU * userFactorValue));
                    this.itemFactors.plus(itemIdx, factorIdx, (double) this.learnRate
                            * (error * userFactorValue - this.lambdaV * itemFactorValue));
                    this.loss += this.lambdaU * userFactorValue * userFactorValue
                            + this.lambdaV * itemFactorValue * itemFactorValue;

                    for (Vector.VectorEntry ve : wordVec) {
                        int wordIdx = ve.index();
                        double wTrue = ve.get();
                        double wPred = this.theta.row(documentId).dot(this.phi.column(wordIdx));
                        double wError = wTrue - wPred;
                        wordLoss += wError * wError;

                        // derivative = sum_t wError * phi[t][w] * J[t][f], where J is the softmax Jacobian
                        // of theta: theta_t * (1 - theta_t) when t == f, and -theta_t * theta_f otherwise.
                        // NOTE(review): the pre-softmax input uses |U| and |V|, but their sign is not
                        // applied in the factor updates below. Confirm this is intended.
                        double derivative = 0.0;
                        for (int topicIdx = 0; topicIdx < this.numTopics; ++topicIdx) {
                            double phiTopicWord = this.phi.get(topicIdx, wordIdx);
                            double thetaDocTopic = this.theta.get(documentId, topicIdx);
                            if (topicIdx == factorIdx) {
                                derivative += wError * phiTopicWord * thetaDocTopic * (1.0 - thetaDocTopic);
                            } else {
                                derivative += wError * phiTopicWord * thetaDocTopic
                                        * -this.theta.get(documentId, factorIdx);
                            }
                            this.K1 += (double) this.learnRate * this.lambda * wError * phiTopicWord
                                    * thetaDocTopic * (1.0 - thetaDocTopic)
                                    * Math.abs(this.userFactors.get(userIdx, topicIdx));
                            this.K2 += (double) this.learnRate * this.lambda * wError * phiTopicWord
                                    * thetaDocTopic * (1.0 - thetaDocTopic)
                                    * Math.abs(this.itemFactors.get(itemIdx, topicIdx));
                        }
                        this.userFactors.plus(userIdx, factorIdx, (double) this.learnRate * this.K1 * derivative);
                        this.itemFactors.plus(itemIdx, factorIdx, (double) this.learnRate * this.K2 * derivative);
                    }
                }
            }
            this.LOG.info(" iter:" + iter + ", finish factors update");
            this.calculateTheta();
            this.LOG.info(" iter:" + iter + ", finish theta update");
            this.updatePhi();
            this.LOG.info(" iter:" + iter + ", finish phi update");
            wordLoss /= (double) this.numTopics;
            this.loss += wordLoss;
            this.loss *= 0.5;
            this.LOG.info(" iter:" + iter + ", loss:" + this.loss + ", wordLoss:" + wordLoss / 2.0);
        }
    }

    /**
     * Predicts the rating for a tensor entry from its first two keys.
     *
     * @param keys tensor keys; {@code keys[0]} is the user index and {@code keys[1]} the item index
     * @return the predicted rating
     * @throws LibrecException declared for the recommender contract
     */
    @Override
    protected double predict(int[] keys) throws LibrecException {
        return this.predict(keys[0], keys[1]);
    }

    /**
     * Predicts a rating as {@code U[u] . V[i] + b_u + b_i + globalMean}.
     *
     * @param userIdx inner user index
     * @param itemIdx inner item index
     * @return the predicted rating
     */
    @Override
    protected double predict(int userIdx, int itemIdx) {
        return this.userFactors.row(userIdx).dot(this.itemFactors.row(itemIdx))
                + this.userBiases.get(userIdx) + this.itemBiases.get(itemIdx) + this.globalMean;
    }

    /**
     * Applies the element-wise multiplicative update
     * {@code phi <- phi * (theta^T W) / (theta^T theta phi)}.
     * Both factors are computed from the current theta and phi before any entry is overwritten.
     */
    private void updatePhi() {
        DenseMatrix thetaTW = this.theta.transpose().times(this.W);
        DenseMatrix denominatorMatrix = this.theta.transpose().times(this.theta).times(this.phi);
        for (int topicIdx = 0; topicIdx < this.numTopics; ++topicIdx) {
            for (int wordIdx = 0; wordIdx < this.numWords; ++wordIdx) {
                double numerator = this.phi.get(topicIdx, wordIdx) * thetaTW.get(topicIdx, wordIdx);
                double denominator = denominatorMatrix.get(topicIdx, wordIdx);
                this.phi.set(topicIdx, wordIdx, numerator / denominator);
            }
        }
    }

    /**
     * Recomputes theta for every training (user, item) pair as
     * {@code softmax(K1 * |U[u]| + K2 * |V[i]|)} and stores it in that pair's document row.
     * A failure for one document is logged and does not stop the remaining updates.
     */
    private void calculateTheta() {
        for (MatrixEntry me : this.trainMatrix) {
            int u = me.row();
            int i = me.column();
            double[] k1uAddk2v = new double[this.numTopics];
            for (int k = 0; k < this.numTopics; ++k) {
                k1uAddk2v[k] = Math.abs(this.userFactors.get(u, k)) * this.K1
                        + Math.abs(this.itemFactors.get(i, k)) * this.K2;
            }
            int documentIdx = this.userItemToDocument.get(u, i);
            try {
                VectorBasedDenseVector newValues = new VectorBasedDenseVector(Maths.softmax(k1uAddk2v));
                this.theta.set(documentIdx, newValues);
            } catch (Exception e) {
                // Keep the previous best-effort semantics (continue with other documents), but report
                // the failure through the recommender's logger with its cause instead of stderr.
                this.LOG.error("failed to compute theta for document " + documentIdx, e);
            }
        }
    }
}

