/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.content;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBasedTable;
import java.util.HashMap;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DataFrame;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SequentialAccessSparseMatrix;
import net.librec.math.structure.SparseStringMatrix;
import net.librec.math.structure.SparseTensor;
import net.librec.math.structure.TensorEntry;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.TensorRecommender;
import net.librec.util.StringUtil;
import org.apache.commons.lang.StringUtils;

/**
 * Review-aware rating recommender in the style of HFT (presumably McAuley &amp; Leskovec,
 * "Hidden Factors and Hidden Topics").
 *
 * <p>Ratings are modelled with biased matrix factorization ({@link #predict(int, int)}). Review
 * words are modelled with a topic model:
 * <ul>
 *   <li>user u's topic distribution {@code thetaus[u]} is the softmax of u's latent factor vector;</li>
 *   <li>topic k's word distribution {@code phiks[k]} is the softmax of row k of {@link #topicToWord}.</li>
 * </ul>
 * Training alternates SGD passes over the ratings with re-sampling of every word's topic
 * assignment from {@code theta_u * phi_k}.
 *
 * <p>Reviews are read from the "review" inner mapping as {@code ':'}-separated tokens. They are
 * stored in {@link #reviewMatrix} re-encoded as {@code ':'}-separated dense word indices in
 * {@code [0, numberOfWords)}, so they can be used directly as columns of {@link #topicToWord}.
 */
public class HFTRecommender
extends TensorRecommender {
    /** Number of SGD sweeps over the training ratings between two topic re-sampling steps. */
    private static final int SGD_PASSES_PER_ITERATION = 5;

    protected SequentialAccessSparseMatrix trainMatrix;
    /** Encoded review of each (user, item) training pair: ':'-separated dense word indices. */
    protected SparseStringMatrix reviewMatrix;
    /** Unnormalized topic-by-word weights; phiks[k] is the softmax of row k. */
    protected DenseMatrix topicToWord;
    /** ':'-separated topic index for every word of every training review. */
    protected SparseStringMatrix topicAssignment;
    /** Number of topics; set equal to numFactors in setup(). */
    protected int K = 10;
    /** Number of distinct non-empty tokens seen in train and test reviews. */
    protected int numberOfWords;
    protected VectorBasedDenseVector userBiases;
    protected VectorBasedDenseVector itemBiases;
    protected DenseMatrix userFactors;
    protected DenseMatrix itemFactors;
    protected float initMean;
    protected float initStd;
    protected double regBias;
    protected float regUser;
    protected float regItem;
    public BiMap<Integer, String> reviewMappingData;
    protected StringUtil str = new StringUtil();
    protected Randoms rn = new Randoms();
    /** Per-user topic distributions (softmax of the user's factor vector). */
    protected double[][] thetaus;
    /** Per-topic word distributions (softmax of the corresponding topicToWord row). */
    protected double[][] phiks;

    /**
     * Reads hyper-parameters and allocates the model.
     *
     * <p>It builds the word dictionary from all train and test reviews, stores the encoded
     * training reviews, and draws uniformly random initial topic assignments for every
     * training word.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.reviewMappingData = DataFrame.getInnerMapping("review").inverse();
        this.regBias = this.conf.getDouble("rec.bias.regularization", 0.01);
        this.regUser = this.conf.getFloat("rec.user.regularization", Float.valueOf(0.01f)).floatValue();
        this.regItem = this.conf.getFloat("rec.item.regularization", Float.valueOf(0.01f)).floatValue();
        this.trainTensor = (SparseTensor) this.getDataModel().getTrainDataSet();
        this.userBiases = new VectorBasedDenseVector(this.numUsers);
        this.itemBiases = new VectorBasedDenseVector(this.numItems);
        this.userFactors = new DenseMatrix(this.numUsers, this.numFactors);
        this.itemFactors = new DenseMatrix(this.numItems, this.numFactors);
        // One topic per latent factor: theta_u is the softmax of user u's factor vector.
        this.K = this.numFactors;
        this.initMean = 0.0f;
        this.initStd = 0.1f;
        this.userBiases.init(this.initMean, this.initStd);
        this.itemBiases.init(this.initMean, this.initStd);
        this.trainMatrix = this.trainTensor.rateMatrix();

        this.numberOfWords = 0;
        HashMap<String, Integer> wordIndex = new HashMap<String, Integer>();
        HashBasedTable<Integer, Integer, String> encodedReviews = HashBasedTable.create();
        for (TensorEntry te : this.trainTensor) {
            int[] keys = te.keys(); // [user, item, review]
            String encoded = this.encodeReview(this.reviewMappingData.get(keys[2]), wordIndex);
            encodedReviews.put(keys[0], keys[1], encoded);
        }
        // Test reviews are only scanned so that their words also get a column in topicToWord/phiks;
        // they are not stored in reviewMatrix.
        if (this.testTensor != null) {
            for (TensorEntry te : this.testTensor) {
                this.encodeReview(this.reviewMappingData.get(te.keys()[2]), wordIndex);
            }
        }
        this.LOG.info("number of users : " + this.numUsers);
        this.LOG.info("number of Items : " + this.numItems);
        this.LOG.info("number of words : " + this.numberOfWords);

        this.reviewMatrix = new SparseStringMatrix(this.numUsers, this.numItems, encodedReviews);
        this.topicToWord = new DenseMatrix(this.K, this.numberOfWords);
        this.topicToWord.init(0.1);
        this.topicAssignment = new SparseStringMatrix(this.reviewMatrix);
        this.thetaus = new double[this.numUsers][this.K];
        this.phiks = new double[this.K][this.numberOfWords];
        for (MatrixEntry me : this.trainMatrix) {
            int u = me.row();
            int j = me.column();
            String words = this.reviewMatrix.get(u, j);
            if (StringUtils.isEmpty(words)) {
                // Review without words: nothing to assign; trainModel/sampleZ skip it as well.
                continue;
            }
            int wordCount = words.split(":").length;
            String[] topics = new String[wordCount];
            for (int i = 0; i < wordCount; ++i) {
                topics[i] = Integer.toString(Randoms.uniform(this.K));
            }
            this.topicAssignment.set(u, j, StringUtil.toString(topics, ":"));
        }
        this.calculateThetas();
        this.calculatePhis();
    }

    /**
     * Maps every non-empty ':'-separated token of {@code review} to its dense word index.
     *
     * <p>Tokens not yet in {@code dictionary} are assigned the next free index, and
     * {@link #numberOfWords} is incremented for each one.
     *
     * @param review raw review text; {@code null} or empty yields {@code ""}
     * @param dictionary token to word-index map, extended in place
     * @return the review as ':'-separated word indices (empty tokens dropped)
     */
    private String encodeReview(String review, HashMap<String, Integer> dictionary) {
        if (StringUtils.isEmpty(review)) {
            return "";
        }
        StringBuilder encoded = new StringBuilder();
        for (String word : review.split(":")) {
            if (StringUtils.isEmpty(word)) {
                continue;
            }
            Integer id = dictionary.get(word);
            if (id == null) {
                id = this.numberOfWords++;
                dictionary.put(word, id);
            }
            if (encoded.length() > 0) {
                encoded.append(':');
            }
            encoded.append(id);
        }
        return encoded.toString();
    }

    /**
     * Refreshes theta/phi from the current parameters, then re-samples the topic of every word
     * of every non-empty training review.
     *
     * @throws Exception propagated from topic sampling
     */
    protected void sampleZ() throws Exception {
        this.calculateThetas();
        this.calculatePhis();
        for (MatrixEntry me : this.trainMatrix) {
            int u = me.row();
            int j = me.column();
            String words = this.reviewMatrix.get(u, j);
            if (StringUtils.isEmpty(words)) {
                continue;
            }
            this.topicAssignment.set(u, j, this.sampleTopicsToWords(words.split(":"), u));
        }
    }

    /**
     * Returns {@code newValues} unless its softmax contains a NaN, in which case
     * {@code oldValues} is returned.
     *
     * <p>This guards theta/phi against numerical blow-ups. {@code newValues} itself is
     * returned as-is; the softmax is only used for the check.
     *
     * @param oldValues values to keep on failure
     * @param newValues candidate values
     * @return {@code newValues} or {@code oldValues}
     * @throws Exception propagated from {@code Maths.softmax}
     */
    protected double[] updateArray(double[] oldValues, double[] newValues) throws Exception {
        for (double value : Maths.softmax(newValues)) {
            if (Double.isNaN(value)) {
                return oldValues;
            }
        }
        return newValues;
    }

    /** Sets each user's topic distribution to the softmax of that user's factor vector. */
    protected void calculateThetas() {
        for (int u = 0; u < this.numUsers; ++u) {
            try {
                this.thetaus[u] = this.updateArray(this.thetaus[u], Maths.softmax(this.userFactors.row(u).getValues()));
            } catch (Exception e) {
                this.LOG.warn("could not update topic distribution of user " + u + "; keeping previous values", e);
            }
        }
    }

    /** Sets each topic's word distribution to the softmax of the corresponding topicToWord row. */
    protected void calculatePhis() {
        for (int k = 0; k < this.K; ++k) {
            try {
                this.phiks[k] = this.updateArray(this.phiks[k], Maths.softmax(this.topicToWord.row(k).getValues()));
            } catch (Exception e) {
                this.LOG.warn("could not update word distribution of topic " + k + "; keeping previous values", e);
            }
        }
    }

    /**
     * Draws a topic for each word from the normalized product {@code theta_u[k] * phi_k[w]}.
     *
     * @param wordsList word indices (as strings) of one review
     * @param u user index
     * @return the sampled topics, ':'-separated, aligned with {@code wordsList}
     * @throws Exception propagated from normalization/sampling helpers
     */
    protected String sampleTopicsToWords(String[] wordsList, int u) throws Exception {
        String[] topicList = new String[wordsList.length];
        for (int i = 0; i < wordsList.length; ++i) {
            int w = Integer.parseInt(wordsList[i]);
            double[] topicDistribute = new double[this.K];
            for (int k = 0; k < this.K; ++k) {
                topicDistribute[k] = this.thetaus[u][k] * this.phiks[k][w];
            }
            topicDistribute = Maths.norm(topicDistribute);
            topicList[i] = Integer.toString(Randoms.discrete(topicDistribute));
        }
        return StringUtil.toString(topicList, ":");
    }

    /**
     * Runs {@code rec.iterator.maximum} outer iterations.
     *
     * <p>Each iteration does {@link #SGD_PASSES_PER_ITERATION} SGD sweeps over the training
     * ratings, followed by one topic re-sampling step ({@link #sampleZ()}).
     *
     * <p>{@code loss} holds half of (squared rating error plus {@code -log2(theta*phi)} summed
     * over review words), computed during the last sweep.
     */
    @Override
    protected void trainModel() {
        double maxIterations = this.conf.getDouble("rec.iterator.maximum");
        for (int iter = 1; iter <= maxIterations; ++iter) {
            for (int sgdIter = 1; sgdIter <= SGD_PASSES_PER_ITERATION; ++sgdIter) {
                this.loss = 0.0;
                for (MatrixEntry me : this.trainMatrix) {
                    int u = me.row();
                    int j = me.column();
                    double euj = me.get() - this.predict(u, j);
                    this.loss += euj * euj;
                    this.updateBiases(u, j, euj);
                    // The rating part is learned for every rating, with or without review text.
                    this.updateRatingFactors(u, j, euj);
                    String review = this.reviewMatrix.get(u, j);
                    if (StringUtils.isEmpty(review)) {
                        continue;
                    }
                    this.updateTopicParameters(u, review.split(":"), this.topicAssignment.get(u, j).split(":"));
                }
                this.loss *= 0.5;
            }
            this.LOG.info(" iter:" + iter + ", loss:" + this.loss);
            try {
                this.LOG.info(" iter:" + iter + ", sampling");
                this.sampleZ();
                this.LOG.info(" iter:" + iter + ", sample finished");
            } catch (Exception e) {
                this.LOG.error(" iter:" + iter + ", topic sampling failed", e);
            }
        }
    }

    /** One regularized SGD step on the user and item biases for prediction error {@code euj}. */
    private void updateBiases(int u, int j, double euj) {
        double bu = this.userBiases.get(u);
        this.userBiases.plus(u, (double) this.learnRate * (euj - this.regBias * bu));
        double bj = this.itemBiases.get(j);
        this.itemBiases.plus(j, (double) this.learnRate * (euj - this.regBias * bj));
    }

    /**
     * One regularized SGD step on the latent factors of user u and item j for error {@code euj}.
     *
     * <p>Both updates for factor f use the values read before either write.
     */
    private void updateRatingFactors(int u, int j, double euj) {
        for (int f = 0; f < this.numFactors; ++f) {
            double puf = this.userFactors.get(u, f);
            double qjf = this.itemFactors.get(j, f);
            this.userFactors.plus(u, f, (double) this.learnRate * (euj * qjf - (double) this.regUser * puf));
            this.itemFactors.plus(j, f, (double) this.learnRate * (euj * puf - (double) this.regItem * qjf));
        }
    }

    /**
     * One gradient-ascent step on {@code log(theta_u[k] * phi_k[w])} for each word w of a review,
     * where k is the word's current topic.
     *
     * <p>The softmax gradients are:
     * <ul>
     *   <li>{@code d log theta_uk / d userFactors[u][f] = 1{f==k} - theta_u[f]};</li>
     *   <li>{@code d log phi_kw / d topicToWord[k][s] = 1{s==w} - phi_k[s]}.</li>
     * </ul>
     * theta/phi are the values cached at the last {@link #calculateThetas()} /
     * {@link #calculatePhis()} call.
     *
     * <p>Also subtracts {@code log2(theta_u[k] * phi_k[w])} from {@code loss} once per word.
     */
    private void updateTopicParameters(int u, String[] words, String[] topics) {
        double lr = this.learnRate;
        double[] theta = this.thetaus[u];
        for (int x = 0; x < words.length; ++x) {
            int w = Integer.parseInt(words[x]);
            int k = Integer.parseInt(topics[x]);
            for (int f = 0; f < this.numFactors; ++f) {
                double indicator = f == k ? 1.0 : 0.0;
                this.userFactors.plus(u, f, lr * (indicator - theta[f]));
            }
            double[] phi = this.phiks[k];
            for (int s = 0; s < this.numberOfWords; ++s) {
                double indicator = s == w ? 1.0 : 0.0;
                this.topicToWord.plus(k, s, lr * (indicator - phi[s]));
            }
            this.loss -= Maths.log(theta[k] * phi[w], 2);
        }
    }

    /** Predicts from the first two indices (user, item) of a tensor entry. */
    @Override
    protected double predict(int[] indices) {
        return this.predict(indices[0], indices[1]);
    }

    /** Biased MF prediction: {@code p_u . q_j + b_u + b_j + globalMean}. */
    @Override
    protected double predict(int userIdx, int itemIdx) {
        return this.userFactors.row(userIdx).dot(this.itemFactors.row(itemIdx)) + this.userBiases.get(userIdx) + this.itemBiases.get(itemIdx) + this.globalMean;
    }
}

