/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.context.rating;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.RowSequentialAccessSparseMatrix;
import net.librec.math.structure.SequentialAccessSparseMatrix;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.Vector;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.cf.rating.BiasedMFRecommender;

/**
 * Time-aware matrix-factorization rating predictor built on top of {@link BiasedMFRecommender}.
 *
 * <p>Every rating is associated with a <em>day offset</em>: the number of whole days between the
 * rating's timestamp and the earliest timestamp found in the datetime data set (timestamps are
 * converted with {@link TimeUnit#SECONDS}, i.e. they are treated as seconds). The predicted rating
 * for user {@code u}, item {@code i} and day offset {@code t} is
 *
 * <pre>
 *   globalMean
 *     + (itemBias[i] + itemSectionBias[i][section(t)]) * (userScale[u] + userDayScale[u][t])
 *     + userBias[u] + userBiasWeight[u] * deviation(u, t) + userDayBias[u][t]
 *     + |R(u)|^-0.5 * sum_{j in R(u)} (itemImplicit[j] . itemExplicit[i])
 *     + sum_k (userExplicit[u][k] + userImplicit[u][k] * deviation(u, t) + userDay[u][t][k])
 *             * itemExplicit[i][k]
 * </pre>
 *
 * where {@code R(u)} is the set of items the user rated in the training matrix.
 *
 * <p>NOTE(review): this has the shape of the timeSVD++ model; confirm against the intended
 * reference before relying on that name.
 */
public class TimeSVDRecommender
extends BiasedMFRecommender {
    /** Number of distinct day offsets, i.e. the largest day offset in the data plus one. */
    private int numDays;
    /** Smallest timestamp found in {@link #instantMatrix}. */
    private long minTimestamp;
    /** Largest timestamp found in {@link #instantMatrix}. */
    private long maxTimestamp;
    /** Timestamp of every (user, item) interaction, as supplied by the data model. */
    private SequentialAccessSparseMatrix instantMatrix;
    /** Mean day offset of each user's training ratings (global mean for users without ratings). */
    private DenseVector userMeanDays;
    /** Number of time sections (bins) used for the item section biases. */
    private int numSections;
    /** Per-user factors that are multiplied by the time deviation. */
    private DenseMatrix userImplicitFactors;
    /** Per-item implicit-feedback factors, summed over the items a user rated. */
    private DenseMatrix itemImplicitFactors;
    /** Static per-user latent factors. */
    private DenseMatrix userExplicitFactors;
    /** Per-item latent factors. */
    private DenseMatrix itemExplicitFactors;
    /** Item bias per time section, indexed {@code [item][section]}. */
    private DenseMatrix itemSectionBiases;
    /** User bias per day offset, indexed {@code [user][day]}. */
    private DenseMatrix userDayBiases;
    /** Per-user weight applied to the time deviation in the bias term. */
    private DenseVector userBiasWeights;
    /** For each user, the day-specific latent factors keyed by day offset. */
    private Map<Integer, double[]>[] userDayFactors;
    /** Per-user multiplier applied to the item biases. */
    private DenseVector userScales;
    /** Per-user, per-day multiplier applied to the item biases, indexed {@code [user][day]}. */
    private DenseMatrix userDayScales;
    /**
     * Exponent used by {@link #deviation(int, int)}. The initial value is always replaced in
     * {@link #setup()} by the {@code rec.timesvd.beta} setting, whose default there is 0.1.
     */
    private double beta = 0.4;
    /** Same sparsity pattern as the training matrix, holding day offsets instead of ratings. */
    private RowSequentialAccessSparseMatrix trainTimeMatrix;
    /** Same sparsity pattern as the test matrix, holding day offsets instead of ratings. */
    private RowSequentialAccessSparseMatrix testTimeMatrix;

    /**
     * Converts a duration to whole days.
     *
     * @param duration duration, interpreted as seconds
     * @return the number of whole days in {@code duration}
     */
    private static int days(long duration) {
        return (int) TimeUnit.SECONDS.toDays(duration);
    }

    /**
     * Returns the number of whole days between two timestamps, regardless of their order.
     *
     * @param t1 first timestamp, interpreted as seconds
     * @param t2 second timestamp, interpreted as seconds
     * @return whole days in {@code |t1 - t2|}
     */
    private static int days(long t1, long t2) {
        return TimeSVDRecommender.days(Math.abs(t1 - t2));
    }

    /**
     * Reads the configuration, allocates and initializes all model parameters, converts the
     * timestamps of the training and test entries into day offsets and computes each user's mean
     * day offset.
     *
     * @throws LibrecException if the superclass setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.beta = this.conf.getDouble("rec.timesvd.beta", 0.1);
        this.numSections = this.conf.getInt("rec.numBins", 20);
        this.instantMatrix = (SequentialAccessSparseMatrix) this.getDataModel().getDatetimeDataSet();
        this.getMaxAndMinTimeStamp();
        this.numDays = TimeSVDRecommender.days(this.maxTimestamp, this.minTimestamp) + 1;

        // The initialization order below is kept stable because each init() draws random numbers.
        this.userBiasWeights = new VectorBasedDenseVector(this.numUsers);
        this.userBiasWeights.init(this.initMean, this.initStd);
        this.itemSectionBiases = new DenseMatrix(this.numItems, this.numSections);
        this.itemSectionBiases.init(this.initMean, this.initStd);
        this.itemImplicitFactors = new DenseMatrix(this.numItems, this.numFactors);
        this.itemImplicitFactors.init(this.initMean, this.initStd);
        this.userImplicitFactors = new DenseMatrix(this.numUsers, this.numFactors);
        this.userImplicitFactors.init(this.initMean, this.initStd);
        this.userDayBiases = new DenseMatrix(this.numUsers, this.numDays);

        // Generic arrays cannot be created directly; the cast is safe because the array is only
        // ever filled with HashMap<Integer, double[]> instances below.
        @SuppressWarnings("unchecked")
        Map<Integer, double[]>[] dayFactorMaps = (Map<Integer, double[]>[]) new Map[this.numUsers];
        for (int user = 0; user < this.numUsers; ++user) {
            dayFactorMaps[user] = new HashMap<>();
        }
        this.userDayFactors = dayFactorMaps;

        this.userScales = new VectorBasedDenseVector(this.numUsers);
        this.userScales.init(this.initMean, this.initStd);
        this.userDayScales = new DenseMatrix(this.numUsers, this.numDays);
        this.userDayScales.init(this.initMean, this.initStd);
        this.userExplicitFactors = new DenseMatrix(this.numUsers, this.numFactors);
        this.userExplicitFactors.init(this.initMean, this.initStd);
        this.itemExplicitFactors = new DenseMatrix(this.numItems, this.numFactors);
        this.itemExplicitFactors.init(this.initMean, this.initStd);

        // Replace every training rating by its day offset and create the day-specific parameters
        // for each (user, day) pair that occurs in the training data.
        double daySum = 0.0;
        int ratingCount = 0;
        this.trainTimeMatrix = new RowSequentialAccessSparseMatrix(this.trainMatrix, true);
        for (MatrixEntry matrixEntry : this.trainTimeMatrix) {
            int userIndex = matrixEntry.row();
            int itemIndex = matrixEntry.column();
            int day = TimeSVDRecommender.days((long) this.instantMatrix.get(userIndex, itemIndex), this.minTimestamp);
            matrixEntry.set(day);
            daySum += (double) day;
            double[] dayFactors = new double[this.numFactors];
            for (int factorIndex = 0; factorIndex < this.numFactors; ++factorIndex) {
                dayFactors[factorIndex] = Randoms.uniform(this.initMean, this.initStd);
            }
            this.userDayFactors[userIndex].put(day, dayFactors);
            this.userDayBiases.set(userIndex, day, Randoms.uniform(this.initMean, this.initStd));
            ++ratingCount;
        }

        // Test-only (user, day) pairs share one all-zero factor array. Training only updates the
        // arrays registered above for training days, so this shared array is never written.
        double[] zeroDayFactors = new double[this.numFactors];
        this.testTimeMatrix = new RowSequentialAccessSparseMatrix(this.testMatrix, true);
        for (MatrixEntry matrixEntry : this.testTimeMatrix) {
            int userIndex = matrixEntry.row();
            int itemIndex = matrixEntry.column();
            int day = TimeSVDRecommender.days((long) this.instantMatrix.get(userIndex, itemIndex), this.minTimestamp);
            matrixEntry.set(day);
            this.userDayFactors[userIndex].putIfAbsent(day, zeroDayFactors);
        }

        double globalMeanDays = ratingCount > 0 ? daySum / (double) ratingCount : 0.0;
        this.userMeanDays = new VectorBasedDenseVector(this.numUsers);
        for (int userIndex = 0; userIndex < this.numUsers; ++userIndex) {
            SequentialSparseVector userVector = this.trainTimeMatrix.row(userIndex);
            double userDaySum = 0.0;
            for (Vector.VectorEntry vectorEntry : userVector) {
                // The entries already hold day offsets (written by the loop above), so they are
                // averaged directly rather than being converted from timestamps a second time.
                userDaySum += vectorEntry.get();
            }
            double mean = userVector.size() > 0 ? userDaySum / (double) userVector.size() : globalMeanDays;
            this.userMeanDays.set(userIndex, mean);
        }
    }

    /**
     * Fits all model parameters with stochastic gradient descent over the training ratings.
     *
     * <p>For every user the scaled sum of the implicit factors of the items the user rated is
     * computed once, then each rating of the user produces one gradient step for the biases,
     * scales and factors. The implicit item factors are updated once per user from the gradient
     * accumulated over that user's ratings. The loop stops early when {@code isConverged}
     * reports convergence.
     *
     * @throws LibrecException if the convergence check fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        VectorBasedDenseVector implicitFeedback = new VectorBasedDenseVector(this.numFactors);
        for (int iterationStep = 1; iterationStep <= this.numIterations; ++iterationStep) {
            this.loss = 0.0;
            for (int userIndex = 0; userIndex < this.numUsers; ++userIndex) {
                SequentialSparseVector rateVector = this.trainMatrix.row(userIndex);
                SequentialSparseVector timeVector = this.trainTimeMatrix.row(userIndex);
                int size = rateVector.size();
                if (size == 0) {
                    continue;
                }
                double scale = Math.pow(size, -0.5);

                // implicitFeedback[k] = |R(u)|^-0.5 * sum_{j in R(u)} itemImplicit[j][k].
                // The vector is reused across users, so it must be cleared first; otherwise the
                // previous users' sums leak into this user's prediction.
                implicitFeedback.assign((index, value) -> 0.0);
                for (Vector.VectorEntry vectorEntry : rateVector) {
                    int ratedItem = vectorEntry.index();
                    implicitFeedback.assign((index, value) -> this.itemImplicitFactors.get(ratedItem, index) + value);
                }
                implicitFeedback.assign((index, value) -> value * scale);

                // Gradient w.r.t. the implicit item factors, accumulated over the user's ratings.
                double[] step = new double[this.numFactors];
                for (Vector.VectorEntry vectorEntry : rateVector) {
                    int itemExplicitIndex = vectorEntry.index();
                    double rate = vectorEntry.get();
                    int day = (int) timeVector.getAtPosition(vectorEntry.position());
                    int section = this.section(day);
                    double deviation = this.deviation(userIndex, day);
                    double userBias = this.userBiases.get(userIndex);
                    double itemBias = this.itemBiases.get(itemExplicitIndex);
                    double userScale = this.userScales.get(userIndex);
                    double dayScale = this.userDayScales.get(userIndex, day);
                    double userDayBias = this.userDayBiases.get(userIndex, day);
                    double itemSectionBias = this.itemSectionBiases.get(itemExplicitIndex, section);
                    double userWeight = this.userBiasWeights.get(userIndex);

                    // Prediction with the current parameters.
                    double prediction = this.globalMean + (itemBias + itemSectionBias) * (userScale + dayScale);
                    prediction += userBias + userWeight * deviation + userDayBias;
                    DenseVector itemExplicitVector = this.itemExplicitFactors.row(itemExplicitIndex);
                    prediction += implicitFeedback.dot(itemExplicitVector);
                    double[] dayFactors = this.userDayFactors[userIndex].get(day);
                    for (int factorIndex = 0; factorIndex < this.numFactors; ++factorIndex) {
                        double qik = this.itemExplicitFactors.get(itemExplicitIndex, factorIndex);
                        double puk = this.userExplicitFactors.get(userIndex, factorIndex)
                                + this.userImplicitFactors.get(userIndex, factorIndex) * deviation
                                + dayFactors[factorIndex];
                        prediction += puk * qik;
                    }
                    double error = prediction - rate;
                    this.loss += error * error;

                    // Bias and scale updates; all gradients use the values read before any update.
                    double sgd = error * (userScale + dayScale) + this.regBias * itemBias;
                    this.itemBiases.plus(itemExplicitIndex, -this.learnRate * sgd);
                    this.loss += this.regBias * itemBias * itemBias;

                    sgd = error * (userScale + dayScale) + this.regBias * itemSectionBias;
                    this.itemSectionBiases.plus(itemExplicitIndex, section, -this.learnRate * sgd);
                    this.loss += this.regBias * itemSectionBias * itemSectionBias;

                    sgd = error * (itemBias + itemSectionBias) + this.regBias * userScale;
                    this.userScales.plus(userIndex, -this.learnRate * sgd);
                    this.loss += this.regBias * userScale * userScale;

                    sgd = error * (itemBias + itemSectionBias) + this.regBias * dayScale;
                    this.userDayScales.plus(userIndex, day, -this.learnRate * sgd);
                    this.loss += this.regBias * dayScale * dayScale;

                    sgd = error + this.regBias * userBias;
                    this.userBiases.plus(userIndex, -this.learnRate * sgd);
                    this.loss += this.regBias * userBias * userBias;

                    sgd = error * deviation + this.regBias * userWeight;
                    this.userBiasWeights.plus(userIndex, -this.learnRate * sgd);
                    this.loss += this.regBias * userWeight * userWeight;

                    sgd = error + this.regBias * userDayBias;
                    this.userDayBiases.set(userIndex, day, userDayBias - this.learnRate * sgd);
                    this.loss += this.regBias * userDayBias * userDayBias;

                    // Latent factor updates.
                    for (int factorIndex = 0; factorIndex < this.numFactors; ++factorIndex) {
                        double userExplicitFactor = this.userExplicitFactors.get(userIndex, factorIndex);
                        double itemExplicitFactor = this.itemExplicitFactors.get(itemExplicitIndex, factorIndex);
                        double userImplicitFactor = this.userImplicitFactors.get(userIndex, factorIndex);
                        double dayFactor = dayFactors[factorIndex];

                        sgd = error * itemExplicitFactor + this.regUser * userExplicitFactor;
                        this.userExplicitFactors.plus(userIndex, factorIndex, -this.learnRate * sgd);
                        this.loss += this.regUser * userExplicitFactor * userExplicitFactor;

                        // The implicit item factors do not change until all of this user's ratings
                        // have been processed, so the precomputed scaled sum is still exact here.
                        sgd = error * (userExplicitFactor + userImplicitFactor * deviation + dayFactor
                                + implicitFeedback.get(factorIndex))
                                + this.regItem * itemExplicitFactor;
                        this.itemExplicitFactors.plus(itemExplicitIndex, factorIndex, -this.learnRate * sgd);
                        this.loss += this.regItem * itemExplicitFactor * itemExplicitFactor;

                        sgd = error * itemExplicitFactor * deviation + this.regUser * userImplicitFactor;
                        this.userImplicitFactors.plus(userIndex, factorIndex, -this.learnRate * sgd);
                        this.loss += this.regUser * userImplicitFactor * userImplicitFactor;

                        sgd = error * itemExplicitFactor + this.regUser * dayFactor;
                        this.loss += this.regUser * dayFactor * dayFactor;
                        dayFactors[factorIndex] = dayFactor - this.learnRate * sgd;

                        step[factorIndex] += error * scale * itemExplicitFactor;
                    }
                }

                // Apply the accumulated gradient to the implicit factors of every rated item.
                for (Vector.VectorEntry vectorEntry : rateVector) {
                    int itemImplicitIndex = vectorEntry.index();
                    for (int factorIndex = 0; factorIndex < this.numFactors; ++factorIndex) {
                        double itemImplicitFactor = this.itemImplicitFactors.get(itemImplicitIndex, factorIndex);
                        double sgd = step[factorIndex] + this.regItem * itemImplicitFactor * (double) size;
                        this.itemImplicitFactors.plus(itemImplicitIndex, factorIndex, -this.learnRate * sgd);
                        this.loss += this.regItem * itemImplicitFactor * itemImplicitFactor * (double) size;
                    }
                }
            }
            this.loss *= 0.5;
            if (this.isConverged(iterationStep)) {
                break;
            }
            this.updateLRate(iterationStep);
            this.lastLoss = this.loss;
        }
    }

    /**
     * Predicts the rating of a user for an item, using the day offset stored for that pair in the
     * test time matrix.
     *
     * @param userIndex index of the user
     * @param itemIndex index of the item
     * @return the predicted rating (not clipped to any rating range by this method)
     */
    @Override
    protected double predict(int userIndex, int itemIndex) {
        int day = (int) this.testTimeMatrix.get(userIndex, itemIndex);
        int section = this.section(day);
        double deviation = this.deviation(userIndex, day);

        double value = this.globalMean;
        value += (this.itemBiases.get(itemIndex) + this.itemSectionBiases.get(itemIndex, section))
                * (this.userScales.get(userIndex) + this.userDayScales.get(userIndex, day));
        value += this.userBiases.get(userIndex)
                + this.userBiasWeights.get(userIndex) * deviation
                + this.userDayBiases.get(userIndex, day);

        // Implicit-feedback term over the items the user rated in the training data.
        SequentialSparseVector userVector = this.trainMatrix.row(userIndex);
        DenseVector itemExplicitVector = this.itemExplicitFactors.row(itemIndex);
        double implicitSum = 0.0;
        for (Vector.VectorEntry vectorEntry : userVector) {
            DenseVector itemImplicitVector = this.itemImplicitFactors.row(vectorEntry.index());
            implicitSum += itemImplicitVector.dot(itemExplicitVector);
        }
        double weight = userVector.size() > 0 ? Math.pow(userVector.size(), -0.5) : 0.0;
        value += implicitSum * weight;

        // A (user, day) pair that was never registered in setup() has no day-specific factors;
        // it contributes nothing instead of failing with a NullPointerException.
        double[] dayFactors = this.userDayFactors[userIndex].get(day);
        for (int factorIndex = 0; factorIndex < this.numFactors; ++factorIndex) {
            double itemExplicitFactor = this.itemExplicitFactors.get(itemIndex, factorIndex);
            double userFactor = this.userExplicitFactors.get(userIndex, factorIndex)
                    + this.userImplicitFactors.get(userIndex, factorIndex) * deviation;
            if (dayFactors != null) {
                userFactor += dayFactors[factorIndex];
            }
            value += userFactor * itemExplicitFactor;
        }
        return value;
    }

    /**
     * Computes the signed, power-compressed distance between a day offset and the user's mean
     * day offset: {@code sign(d) * |d|^beta} with {@code d = days - userMeanDays[userIndex]}.
     *
     * @param userIndex index of the user
     * @param days day offset of the rating
     * @return the time deviation of {@code days} for this user
     */
    private double deviation(int userIndex, int days) {
        double offsetFromMean = (double) days - this.userMeanDays.get(userIndex);
        return Math.signum(offsetFromMean) * Math.pow(Math.abs(offsetFromMean), this.beta);
    }

    /**
     * Maps a day offset to the index of its time section by splitting {@code [0, numDays)} into
     * {@code numSections} equally wide bins.
     *
     * @param days day offset of the rating
     * @return the section index for {@code days}
     */
    private int section(int days) {
        return (int) ((double) days / (double) this.numDays * (double) this.numSections);
    }

    /**
     * Scans every entry of the timestamp matrix and records the smallest and largest timestamp in
     * {@link #minTimestamp} and {@link #maxTimestamp}.
     */
    private void getMaxAndMinTimeStamp() {
        this.minTimestamp = Long.MAX_VALUE;
        this.maxTimestamp = Long.MIN_VALUE;
        for (MatrixEntry entry : this.instantMatrix) {
            long timeStamp = (long) entry.get();
            this.minTimestamp = Math.min(this.minTimestamp, timeStamp);
            this.maxTimestamp = Math.max(this.maxTimestamp, timeStamp);
        }
    }
}

