/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender;

import java.util.ArrayList;
import java.util.List;
import net.librec.common.LibrecException;
import net.librec.data.structure.AbstractBaseDataEntry;
import net.librec.data.structure.BaseContextRatingDataEntry;
import net.librec.data.structure.BaseDataList;
import net.librec.data.structure.BaseRankingDataEntry;
import net.librec.data.structure.LibrecDataList;
import net.librec.job.progress.ProgressBar;
import net.librec.math.structure.DataSet;
import net.librec.math.structure.SequentialAccessSparseMatrix;
import net.librec.math.structure.SparseTensor;
import net.librec.math.structure.TensorEntry;
import net.librec.recommender.AbstractRecommender;
import net.librec.recommender.item.KeyValue;
import net.librec.recommender.item.RecommendedList;

/**
 * Base class for recommenders trained on a {@link SparseTensor} that has a user
 * dimension, an item dimension and possibly further (context) dimensions.
 *
 * <p>Subclasses implement {@link #trainModel()} and {@link #predict(int[])}. To
 * produce meaningful item rankings they must also override
 * {@link #predict(int, int)}, whose default implementation scores every item as
 * {@code 0.0}.
 */
public abstract class TensorRecommender extends AbstractRecommender {
    /** Training data, taken from the data model in {@link #setup()}. */
    protected SparseTensor trainTensor;
    /** Test data; replaced by the argument of {@link #recommendRating(DataSet)}. */
    protected SparseTensor testTensor;
    /** Validation data, taken from the data model in {@link #setup()}. */
    protected SparseTensor validTensor;
    /** Number of dimensions of the training tensor. */
    protected int numDimensions;
    /** Size of each dimension of the training tensor. */
    protected int[] dimensions;
    /** Number of factors ({@code rec.factor.number}, default 10). */
    protected int numFactors;
    /** Current learning rate ({@code rec.iterator.learnrate}, default 0.01); adapted by {@link #updateLRate(int)}. */
    protected float learnRate;
    /** Cap applied to the learning rate when positive ({@code rec.iterator.learnrate.maximum}, default 1000). */
    protected float maxLearnRate;
    /** Maximum number of iterations ({@code rec.iterator.maximum}, default 100). */
    protected int numIterations;
    /** Position of the user index inside a tensor key. */
    protected int userDimension;
    /** Position of the item index inside a tensor key. */
    protected int itemDimension;
    /** Regularization coefficient ({@code rec.tensor.regularization}, default 0.01). */
    protected float reg;
    /** Size of the user dimension. */
    protected int numUsers;
    /** Size of the item dimension. */
    protected int numItems;
    /** Largest training rating; upper clamp of bounded predictions. Starts at the lowest finite double. */
    protected double maxRate = -Double.MAX_VALUE;
    /** Smallest training rating; lower clamp of bounded predictions. Starts at the largest finite double. */
    protected double minRate = Double.MAX_VALUE;
    /** Mean training rating; substituted for NaN rating predictions. */
    protected double globalMean;
    /** Progress bar, created in {@link #setup()} only when verbose output is on. */
    protected ProgressBar progressBar;
    /** Training ratings as a matrix, accessed below as row = user index, column = item index. */
    protected SequentialAccessSparseMatrix trainMatrix;
    /** Test ratings as a matrix, accessed below as row = user index, column = item index. */
    protected SequentialAccessSparseMatrix testMatrix;

    /**
     * Reads the hyper-parameters and caches the train/test/validation tensors and
     * their rating matrices. It then computes min/max/mean of the training
     * ratings. Finally, it stores per-user and per-item counts in the
     * configuration under {@code rec.eval.*} keys (presumably consumed by the
     * evaluators).
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.learnRate = this.conf.getFloat("rec.iterator.learnrate", 0.01f);
        this.maxLearnRate = this.conf.getFloat("rec.iterator.learnrate.maximum", 1000.0f);
        this.numFactors = this.conf.getInt("rec.factor.number", 10);
        this.reg = this.conf.getFloat("rec.tensor.regularization", 0.01f);
        this.numIterations = this.conf.getInt("rec.iterator.maximum", 100);

        this.trainTensor = (SparseTensor) this.getDataModel().getTrainDataSet();
        this.testTensor = (SparseTensor) this.getDataModel().getTestDataSet();
        this.validTensor = (SparseTensor) this.getDataModel().getValidDataSet();
        this.trainMatrix = this.trainTensor.rateMatrix();
        this.testMatrix = this.testTensor.rateMatrix();

        // Training-rating statistics: min/max clamp predict(keys, true), and the
        // mean replaces NaN predictions in recommendRating.
        int size = 0;
        double sum = 0.0;
        for (TensorEntry entry : this.trainTensor) {
            double rate = entry.get();
            this.maxRate = Math.max(this.maxRate, rate);
            this.minRate = Math.min(this.minRate, rate);
            ++size;
            sum += rate;
        }
        // NOTE: evaluates to NaN when the training tensor is empty.
        this.globalMean = sum / size;

        this.numDimensions = this.trainTensor.numDimensions();
        this.dimensions = this.trainTensor.dimensions();
        this.userDimension = this.trainTensor.getUserDimension();
        this.itemDimension = this.trainTensor.getItemDimension();
        this.numUsers = this.dimensions[this.userDimension];
        this.numItems = this.dimensions[this.itemDimension];

        if (verbose) {
            this.progressBar = new ProgressBar(100.0, 100);
        }

        // Per user: number of items absent from the training row, and the largest
        // number of test entries any user has.
        int[] numDroppedItemsArray = new int[this.numUsers];
        int maxNumTestItemsByUser = 0;
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            numDroppedItemsArray[userIdx] = this.numItems - this.trainMatrix.row(userIdx).getNumEntries();
            maxNumTestItemsByUser = Math.max(maxNumTestItemsByUser, this.testMatrix.row(userIdx).getNumEntries());
        }
        // Per item: number of train + test entries.
        int[] itemPurchasedCount = new int[this.numItems];
        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            itemPurchasedCount[itemIdx] = this.trainMatrix.column(itemIdx).getNumEntries()
                    + this.testMatrix.column(itemIdx).getNumEntries();
        }
        this.conf.setInts("rec.eval.auc.dropped.num", numDroppedItemsArray);
        this.conf.setInt("rec.eval.key.test.max.num", maxNumTestItemsByUser);
        this.conf.setInt("rec.eval.item.num", this.testMatrix.columnSize());
        this.conf.setInts("rec.eval.item.purchase.num", itemPurchasedCount);
    }

    /**
     * Trains the model; implemented by subclasses.
     *
     * @throws LibrecException on training failure
     */
    @Override
    protected abstract void trainModel() throws LibrecException;

    /**
     * Ranks items for every user index in {@code [0, numUsers)}. It builds one
     * {@link BaseRankingDataEntry} per user and delegates to
     * {@link #recommendRank(LibrecDataList)}.
     *
     * @return one top-N list per user, in user-index order
     * @throws LibrecException on ranking failure
     */
    @Override
    public RecommendedList recommendRank() throws LibrecException {
        BaseDataList<AbstractBaseDataEntry> allUsers = new BaseDataList<AbstractBaseDataEntry>();
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            allUsers.addDataEntry(new BaseRankingDataEntry(userIdx));
        }
        return this.recommendRank(allUsers);
    }

    /**
     * Produces a top-N item list for each {@link BaseRankingDataEntry} in
     * {@code dataList}; the list at position {@code i} belongs to entry {@code i}.
     *
     * <p>Every item missing from the entry's user row in {@link #trainMatrix} is
     * scored with {@link #predict(int, int)}. NaN scores are dropped, and items
     * whose prediction throws are logged and skipped. Entries are processed with
     * a parallel stream; each task writes only its own pre-allocated slot.
     *
     * @param dataList entries to rank; each must be a {@link BaseRankingDataEntry}
     * @return the ranked lists
     * @throws LibrecException on ranking failure
     * @throws IndexOutOfBoundsException if the resulting list has size 0
     */
    @Override
    public RecommendedList recommendRank(LibrecDataList<AbstractBaseDataEntry> dataList) throws LibrecException {
        int numDataEntries = dataList.size();
        RecommendedList recommendedList = new RecommendedList(this.numUsers);
        List<Integer> entryIndices = new ArrayList<>(numDataEntries);
        for (int entryIdx = 0; entryIdx < numDataEntries; ++entryIdx) {
            entryIndices.add(entryIdx);
            // Allocate every slot up front so the parallel tasks only replace them.
            recommendedList.addList(new ArrayList<KeyValue<Integer, Double>>());
        }
        entryIndices.parallelStream().forEach(boxedIdx -> {
            final int slot = boxedIdx;
            BaseRankingDataEntry entry = (BaseRankingDataEntry) dataList.getDataEntry(slot);
            int userIdx = entry.getUserId();
            int[] trainItems = this.trainMatrix.row(userIdx).getIndices();
            List<KeyValue<Integer, Double>> itemValueList = new ArrayList<>();
            // Advance through the user's training item indices in step with itemIdx
            // so those items are skipped.
            // NOTE(review): relies on getIndices() being sorted ascending — confirm.
            int trainPos = 0;
            for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                if (trainPos < trainItems.length && trainItems[trainPos] == itemIdx) {
                    ++trainPos;
                    continue;
                }
                double predictRating;
                try {
                    predictRating = this.predict(userIdx, itemIdx);
                } catch (LibrecException e) {
                    // Skip the item instead of ranking it with a made-up score.
                    this.LOG.error("Prediction failed for user " + userIdx + ", item " + itemIdx, e);
                    continue;
                }
                if (!Double.isNaN(predictRating)) {
                    itemValueList.add(new KeyValue<Integer, Double>(itemIdx, predictRating));
                }
            }
            recommendedList.setList(slot, itemValueList);
            recommendedList.topNRankByIndex(slot, this.topN);
        });
        if (recommendedList.size() == 0) {
            throw new IndexOutOfBoundsException("No item is recommended, there is something error in the recommendation algorithm! Please check it!");
        }
        return recommendedList;
    }

    /**
     * Predicts a bounded rating for every entry of {@code predictDataSet}.
     * As a side effect, the set is stored as {@link #testTensor}. NaN predictions
     * are replaced by {@link #globalMean}.
     *
     * @param predictDataSet entries to predict; must be a {@link SparseTensor}
     * @return one list per user index holding (item, rating) pairs
     * @throws LibrecException on prediction failure
     */
    @Override
    public RecommendedList recommendRating(DataSet predictDataSet) throws LibrecException {
        this.testTensor = (SparseTensor) predictDataSet;
        RecommendedList recommendedList = new RecommendedList(this.numUsers);
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            recommendedList.addList(new ArrayList<KeyValue<Integer, Double>>());
        }
        for (TensorEntry entry : this.testTensor) {
            int[] keys = entry.keys();
            int userIdx = entry.key(this.userDimension);
            int itemIdx = entry.key(this.itemDimension);
            double predictRating = this.predict(keys, true);
            if (Double.isNaN(predictRating)) {
                predictRating = this.globalMean;
            }
            recommendedList.add(userIdx, itemIdx, predictRating);
        }
        return recommendedList;
    }

    /**
     * Predicts bounded ratings for each {@link BaseContextRatingDataEntry} in
     * {@code dataList}; the list at position {@code i} belongs to entry {@code i}.
     * NaN predictions are replaced by {@link #globalMean}.
     *
     * <p>Exactly one result list is allocated per entry. Previously {@code numUsers}
     * extra lists were created, and they stayed empty.
     *
     * @param dataList entries to predict; each must be a {@link BaseContextRatingDataEntry}
     * @return one list of (item, rating) pairs per entry
     * @throws LibrecException on prediction failure
     */
    @Override
    public RecommendedList recommendRating(LibrecDataList<AbstractBaseDataEntry> dataList) throws LibrecException {
        int numDataEntries = dataList.size();
        RecommendedList recommendedList = new RecommendedList(numDataEntries);
        for (int contextIdx = 0; contextIdx < numDataEntries; ++contextIdx) {
            recommendedList.addList(new ArrayList<KeyValue<Integer, Double>>());
        }
        for (int contextIdx = 0; contextIdx < numDataEntries; ++contextIdx) {
            BaseContextRatingDataEntry entry = (BaseContextRatingDataEntry) dataList.getDataEntry(contextIdx);
            int userIdx = entry.getUserId();
            int[] itemIdsArray = entry.getItemIdsArray();
            int[][] contexts = entry.getContexts();
            for (int index = 0; index < itemIdsArray.length; ++index) {
                int itemIdx = itemIdsArray[index];
                // Key layout: [user, item, contexts...].
                // NOTE(review): assumes userDimension == 0 and itemDimension == 1 — confirm.
                int[] keys = new int[contexts[index].length + 2];
                keys[0] = userIdx;
                keys[1] = itemIdx;
                System.arraycopy(contexts[index], 0, keys, 2, contexts[index].length);
                double predictRating = this.predict(keys, true);
                if (Double.isNaN(predictRating)) {
                    predictRating = this.globalMean;
                }
                recommendedList.add(contextIdx, itemIdx, predictRating);
            }
        }
        return recommendedList;
    }

    /**
     * Predicts the rating of the tensor cell addressed by {@code keys}. The value
     * is clamped by {@link #predict(int[], boolean)} when bounding is requested.
     *
     * @param keys one index per tensor dimension
     * @return the predicted rating
     * @throws LibrecException on prediction failure
     */
    protected abstract double predict(int[] keys) throws LibrecException;

    /**
     * Scores an item for a user when ranking. This default returns {@code 0.0}
     * for every pair, so subclasses should override it.
     *
     * @param userIdx user index
     * @param itemIdx item index
     * @return the ranking score
     * @throws LibrecException on prediction failure
     */
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        return 0.0;
    }

    /**
     * Calls {@link #predict(int[])} and, if {@code bound} is true, clamps the
     * result to {@code [minRate, maxRate]}.
     *
     * @param keys  one index per tensor dimension
     * @param bound whether to clamp to the training rating range
     * @return the (possibly clamped) prediction
     * @throws LibrecException on prediction failure
     */
    protected double predict(int[] keys, boolean bound) throws LibrecException {
        double predictRating = this.predict(keys);
        if (bound) {
            if (predictRating > this.maxRate) {
                predictRating = this.maxRate;
            } else if (predictRating < this.minRate) {
                predictRating = this.minRate;
            }
        }
        return predictRating;
    }

    /**
     * Logs the loss change when verbose. Training counts as converged once
     * {@code |lastLoss - loss| < 1e-5}.
     *
     * @param iter current iteration number, used only in the log line
     * @return whether the loss change is below the threshold
     * @throws LibrecException if the loss is NaN or infinite
     */
    @Override
    protected boolean isConverged(int iter) throws LibrecException {
        float deltaLoss = (float) (this.lastLoss - this.loss);
        if (verbose) {
            String recName = this.getClass().getSimpleName();
            String info = recName + " iter " + iter + ": loss = " + this.loss + ", delta_loss = " + deltaLoss;
            this.LOG.info(info);
        }
        if (Double.isNaN(this.loss) || Double.isInfinite(this.loss)) {
            throw new LibrecException("Loss = NaN or Infinity: current settings does not fit the recommender! Change the settings and try again!");
        }
        return Math.abs(deltaLoss) < 1.0E-5;
    }

    /**
     * Adapts the learning rate after an iteration and records the current loss
     * as {@code lastLoss} for the next convergence check.
     *
     * <p>If bold driver is enabled and {@code iter > 1}, the rate grows by 5% when
     * |loss| decreased and is halved otherwise. Otherwise, a decay in (0, 1) is
     * applied multiplicatively. A positive {@link #maxLearnRate} caps the result.
     * A negative learning rate is left untouched.
     *
     * @param iter current iteration number
     */
    protected void updateLRate(int iter) {
        if (this.learnRate < 0.0f) {
            // Still advance lastLoss; otherwise isConverged() keeps comparing
            // against a stale loss.
            this.lastLoss = this.loss;
            return;
        }
        if (this.isBoldDriver && iter > 1) {
            this.learnRate = Math.abs(this.lastLoss) > Math.abs(this.loss)
                    ? this.learnRate * 1.05f
                    : this.learnRate * 0.5f;
        } else if (this.decay > 0.0f && this.decay < 1.0f) {
            this.learnRate *= this.decay;
        }
        if (this.maxLearnRate > 0.0f && this.learnRate > this.maxLearnRate) {
            this.learnRate = this.maxLearnRate;
        }
        this.lastLoss = this.loss;
    }
}

