/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Gamma;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.SequentialAccessSparseMatrix;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.MatrixProbabilisticGraphicalRecommender;
import net.librec.util.RatingContext;
import org.apache.commons.lang.ArrayUtils;

/**
 * Topic-based ranking recommender that conditions each item on the item the
 * same user interacted with immediately before it (an item "bigram").
 *
 * <p>Every user's training items are ordered via the datetime data set. Each
 * (user, item) token carries a topic assignment; the model tracks how often a
 * topic was used by a user and how often, under a topic, a given predecessor
 * item was followed by a given current item. The first item of a user has no
 * predecessor, so the extra index {@code numItems} is used as a "start"
 * predecessor; this is why predecessor dimensions have size
 * {@code numItems + 1}.
 *
 * <p>Array layout used throughout: {@code [topic][predecessor item][current item]}.
 */
@ModelData(value={"isRanking", "itembigram", "userTopicProbs", "topicPreItemCurItemProbs"})
public class ItemBigramRecommender
extends MatrixProbabilisticGraphicalRecommender {
    /** Per user: training items ordered by {@link RatingContext}'s natural ordering. */
    private Map<Integer, List<Integer>> userItemsMap;
    /** Counts indexed as {@code [topic][predecessor item][current item]}. */
    private int[][][] topicPreItemCurItemNum;
    /** Despite its name this holds counts: tokens per (topic, predecessor item). */
    private DenseMatrix topicItemProbs;
    /** Averaged transition probabilities, {@code [topic][predecessor item][current item]}. */
    private double[][][] topicPreItemCurItemProbs;
    /** Running sum of the per-readout transition probabilities. */
    private double[][][] topicPreItemCurItemSumProbs;
    /** Prior parameters per (topic, predecessor item), initialised to {@link #initBeta}. */
    private DenseMatrix beta;
    /** Prior parameters per topic, initialised to {@link #initAlpha}. */
    protected VectorBasedDenseVector alpha;
    /** Number of topics, read from {@code rec.topic.number}. */
    protected int numTopics;
    /** Initial value of every {@link #alpha} entry, read from {@code rec.user.dirichlet.prior}. */
    protected float initAlpha;
    /** Initial value of every {@link #beta} entry, read from {@code rec.topic.dirichlet.prior}. */
    protected float initBeta;
    /** Running sum of the per-readout user-topic probabilities. */
    protected DenseMatrix userTopicProbsSum;
    /** Number of tokens of each user assigned to each topic. */
    protected DenseMatrix userTopicNumbers;
    /** Number of tokens of each user. */
    protected VectorBasedDenseVector userTokenNumbers;
    /** Averaged user-topic probabilities, produced by {@link #estimateParams()}. */
    protected DenseMatrix userTopicProbs;
    /** Current topic of each (user, item) token. */
    protected Table<Integer, Integer, Integer> topicAssignments;
    private SequentialAccessSparseMatrix timeMatrix;
    private Table<Integer, Integer, Double> timeTable;

    /**
     * Reads the configuration, builds the ordered item list of every user,
     * allocates all count/probability structures and assigns a uniformly
     * random initial topic to every (user, item) token.
     *
     * @throws LibrecException if the superclass setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.numTopics = this.conf.getInt("rec.topic.number", 10);
        this.initAlpha = this.conf.getFloat("rec.user.dirichlet.prior", Float.valueOf(0.01f)).floatValue();
        this.initBeta = this.conf.getFloat("rec.topic.dirichlet.prior", Float.valueOf(0.01f)).floatValue();
        this.timeMatrix = (SequentialAccessSparseMatrix)this.getDataModel().getDatetimeDataSet();
        this.timeTable = this.timeMatrix.getDataTable();

        // Order each user's training items through RatingContext's natural ordering.
        this.userItemsMap = new HashMap<Integer, List<Integer>>();
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            int[] itemIndexes = this.trainMatrix.row(userIdx).getIndices();
            List<RatingContext> contexts = new ArrayList<RatingContext>(itemIndexes.length);
            for (int itemIdx : itemIndexes) {
                contexts.add(new RatingContext(userIdx, itemIdx, this.timeTable.get(userIdx, itemIdx).longValue()));
            }
            Collections.sort(contexts);
            List<Integer> sortedItems = new ArrayList<Integer>(itemIndexes.length);
            for (RatingContext context : contexts) {
                sortedItems.add(context.getItem());
            }
            this.userItemsMap.put(userIdx, sortedItems);
        }

        this.userTopicNumbers = new DenseMatrix(this.numUsers, this.numTopics);
        this.userTokenNumbers = new VectorBasedDenseVector(this.numUsers);
        // Predecessor dimension has one extra slot (index numItems) for "no predecessor".
        this.topicPreItemCurItemNum = new int[this.numTopics][this.numItems + 1][this.numItems];
        this.topicItemProbs = new DenseMatrix(this.numTopics, this.numItems + 1);
        this.userTopicProbsSum = new DenseMatrix(this.numUsers, this.numTopics);
        this.topicPreItemCurItemSumProbs = new double[this.numTopics][this.numItems + 1][this.numItems];
        this.topicPreItemCurItemProbs = new double[this.numTopics][this.numItems + 1][this.numItems];
        this.alpha = new VectorBasedDenseVector(this.numTopics);
        this.alpha.assign((index, value) -> this.initAlpha);
        this.beta = new DenseMatrix(this.numTopics, this.numItems + 1);
        this.beta.assign((row, column, value) -> this.initBeta);

        // Random initial topic per token, with all counters updated accordingly.
        this.topicAssignments = HashBasedTable.create();
        for (Map.Entry<Integer, List<Integer>> userItemEntry : this.userItemsMap.entrySet()) {
            int userIdx = userItemEntry.getKey();
            List<Integer> items = userItemEntry.getValue();
            for (int position = 0; position < items.size(); ++position) {
                int itemIdx = items.get(position);
                int topicIdx = (int)(Math.random() * (double)this.numTopics);
                int preItemIdx = position > 0 ? items.get(position - 1) : this.numItems;
                this.topicAssignments.put(userIdx, itemIdx, topicIdx);
                this.userTopicNumbers.plus(userIdx, topicIdx, 1.0);
                this.userTokenNumbers.plus(userIdx, 1.0);
                this.topicPreItemCurItemNum[topicIdx][preItemIdx][itemIdx]++;
                this.topicItemProbs.plus(topicIdx, preItemIdx, 1.0);
            }
        }
    }

    /**
     * Resamples the topic of every (user, item) token: the token is removed
     * from all counters, a weight is computed for each candidate topic from
     * the remaining counts and the priors, a topic is drawn proportionally to
     * these weights, and the token is added back under the drawn topic.
     */
    @Override
    protected void eStep() {
        double sumAlpha = this.alpha.sum();
        // beta is not modified in this method, so its row sums are loop-invariant.
        double[] betaRowSums = new double[this.numTopics];
        for (int topic = 0; topic < this.numTopics; ++topic) {
            betaRowSums[topic] = this.beta.row(topic).sum();
        }
        for (Map.Entry<Integer, List<Integer>> userItemEntry : this.userItemsMap.entrySet()) {
            int userIdx = userItemEntry.getKey();
            List<Integer> items = userItemEntry.getValue();
            for (int position = 0; position < items.size(); ++position) {
                int itemIdx = items.get(position);
                int preItemIdx = position > 0 ? items.get(position - 1) : this.numItems;

                // Take the token out of the counts before computing the weights.
                int oldTopic = this.topicAssignments.get(userIdx, itemIdx);
                this.userTopicNumbers.plus(userIdx, oldTopic, -1.0);
                this.userTokenNumbers.plus(userIdx, -1.0);
                this.topicPreItemCurItemNum[oldTopic][preItemIdx][itemIdx]--;
                this.topicItemProbs.plus(oldTopic, preItemIdx, -1.0);

                // Cumulative (unnormalised) weights over the candidate topics.
                double[] cumulativeProbs = new double[this.numTopics];
                double running = 0.0;
                for (int candidate = 0; candidate < this.numTopics; ++candidate) {
                    // The user-topic count must be that of the candidate topic,
                    // not of the topic the token was previously assigned to.
                    double userPart = (this.userTopicNumbers.get(userIdx, candidate) + this.alpha.get(candidate)) / (this.userTokenNumbers.get(userIdx) + sumAlpha);
                    double itemPart = ((double)this.topicPreItemCurItemNum[candidate][preItemIdx][itemIdx] + this.beta.get(candidate, preItemIdx)) / (this.topicItemProbs.get(candidate, preItemIdx) + betaRowSums[candidate]);
                    running = candidate == 0 ? userPart * itemPart : running + userPart * itemPart;
                    cumulativeProbs[candidate] = running;
                }

                // Draw the new topic; the bound keeps the index valid even if no
                // cumulative weight exceeds the draw (e.g. zero or NaN total mass).
                double rand = Randoms.uniform() * cumulativeProbs[this.numTopics - 1];
                int newTopic = 0;
                while (newTopic < this.numTopics - 1 && !(rand < cumulativeProbs[newTopic])) {
                    ++newTopic;
                }

                // Put the token back under the drawn topic.
                this.topicAssignments.put(userIdx, itemIdx, newTopic);
                this.userTopicNumbers.plus(userIdx, newTopic, 1.0);
                this.userTokenNumbers.plus(userIdx, 1.0);
                this.topicPreItemCurItemNum[newTopic][preItemIdx][itemIdx]++;
                this.topicItemProbs.plus(newTopic, preItemIdx, 1.0);
            }
        }
    }

    /**
     * Re-estimates {@link #alpha} and {@link #beta} from the current counts
     * using digamma-based multiplicative updates. An entry whose numerator is
     * exactly zero is left unchanged.
     */
    @Override
    protected void mStep() {
        // ---- alpha update ----
        double sumAlpha = this.alpha.sum();
        double alphaDigamma = Gamma.digamma(sumAlpha);
        double alphaDenominator = 0.0;
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            alphaDenominator += Gamma.digamma(this.userTokenNumbers.get(userIdx) + sumAlpha) - alphaDigamma;
        }
        for (int topicIdx = 0; topicIdx < this.numTopics; ++topicIdx) {
            double alphaTopicValue = this.alpha.get(topicIdx);
            double alphaTopicDigamma = Gamma.digamma(alphaTopicValue);
            double numerator = 0.0;
            for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
                numerator += Gamma.digamma(this.userTopicNumbers.get(userIdx, topicIdx) + alphaTopicValue) - alphaTopicDigamma;
            }
            if (numerator == 0.0) {
                continue;
            }
            this.alpha.set(topicIdx, alphaTopicValue * (numerator / alphaDenominator));
        }

        // ---- beta update ----
        int numPreItems = this.numItems + 1;
        for (int topicIdx = 0; topicIdx < this.numTopics; ++topicIdx) {
            // Row sum and all denominators are taken before any entry of this row is updated.
            double betaTopicValue = this.beta.row(topicIdx).sum();
            double betaTopicDigamma = Gamma.digamma(betaTopicValue);
            double[] preItemDenominators = new double[numPreItems];
            for (int preItemIdx = 0; preItemIdx < numPreItems; ++preItemIdx) {
                preItemDenominators[preItemIdx] = Gamma.digamma(this.topicItemProbs.get(topicIdx, preItemIdx) + betaTopicValue) - betaTopicDigamma;
            }
            for (int preItemIdx = 0; preItemIdx < numPreItems; ++preItemIdx) {
                double betaValue = this.beta.get(topicIdx, preItemIdx);
                double betaDigamma = Gamma.digamma(betaValue);
                double numerator = 0.0;
                double denominator = 0.0;
                for (int curItemIdx = 0; curItemIdx < this.numItems; ++curItemIdx) {
                    numerator += Gamma.digamma((double)this.topicPreItemCurItemNum[topicIdx][preItemIdx][curItemIdx] + betaValue) - betaDigamma;
                    denominator += preItemDenominators[preItemIdx];
                }
                if (numerator == 0.0) {
                    continue;
                }
                this.beta.set(topicIdx, preItemIdx, betaValue * (numerator / denominator));
            }
        }
    }

    /**
     * Adds the probabilities implied by the current counts and priors to the
     * running sums ({@link #userTopicProbsSum} and
     * {@code topicPreItemCurItemSumProbs}) and increments {@code numStats}.
     */
    @Override
    protected void readoutParams() {
        double sumAlpha = this.alpha.sum();
        // Iterate over every user (rows of userTopicProbsSum), not over topics.
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            double userDenominator = this.userTokenNumbers.get(userIdx) + sumAlpha;
            for (int topicIdx = 0; topicIdx < this.numTopics; ++topicIdx) {
                double val = (this.userTopicNumbers.get(userIdx, topicIdx) + this.alpha.get(topicIdx)) / userDenominator;
                this.userTopicProbsSum.plus(userIdx, topicIdx, val);
            }
        }
        for (int topicIdx = 0; topicIdx < this.numTopics; ++topicIdx) {
            double betaTopicValue = this.beta.row(topicIdx).sum();
            for (int preItemIdx = 0; preItemIdx < this.numItems + 1; ++preItemIdx) {
                double betaValue = this.beta.get(topicIdx, preItemIdx);
                double denominator = this.topicItemProbs.get(topicIdx, preItemIdx) + betaTopicValue;
                int[] counts = this.topicPreItemCurItemNum[topicIdx][preItemIdx];
                double[] sums = this.topicPreItemCurItemSumProbs[topicIdx][preItemIdx];
                for (int curItemIdx = 0; curItemIdx < this.numItems; ++curItemIdx) {
                    sums[curItemIdx] += ((double)counts[curItemIdx] + betaValue) / denominator;
                }
            }
        }
        ++this.numStats;
    }

    /**
     * Turns the running sums collected by {@link #readoutParams()} into
     * averages by dividing them by {@code numStats}.
     */
    @Override
    protected void estimateParams() {
        this.userTopicProbs = this.userTopicProbsSum.times(1.0 / (double)this.numStats);
        for (int topicIdx = 0; topicIdx < this.numTopics; ++topicIdx) {
            for (int preItemIdx = 0; preItemIdx < this.numItems + 1; ++preItemIdx) {
                double[] sums = this.topicPreItemCurItemSumProbs[topicIdx][preItemIdx];
                double[] probs = this.topicPreItemCurItemProbs[topicIdx][preItemIdx];
                for (int curItemIdx = 0; curItemIdx < this.numItems; ++curItemIdx) {
                    probs[curItemIdx] = sums[curItemIdx] / (double)this.numStats;
                }
            }
        }
    }

    /**
     * Scores an item for a user as the sum over topics of the user's topic
     * probability times the probability of the item following the user's last
     * training item under that topic. Users without training items use the
     * "start" predecessor index {@code numItems}.
     *
     * @param userIdx user index
     * @param itemIdx candidate item index
     * @return the ranking score of the item for the user
     * @throws LibrecException declared by the overridden method
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        List<Integer> items = this.userItemsMap.get(userIdx);
        int preItemIdx = items.size() < 1 ? this.numItems : items.get(items.size() - 1);
        double predictRating = 0.0;
        for (int topicIdx = 0; topicIdx < this.numTopics; ++topicIdx) {
            predictRating += this.userTopicProbs.get(userIdx, topicIdx) * this.topicPreItemCurItemProbs[topicIdx][preItemIdx][itemIdx];
        }
        return predictRating;
    }
}

