/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import com.google.common.cache.LoadingCache;
import com.google.common.collect.Table;
import java.util.List;
import java.util.concurrent.ExecutionException;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Item-based factor model trained with stochastic gradient descent on a squared-error loss.
 *
 * <p>As implemented by {@link #predict(int, int)}, the score of item {@code j} for user
 * {@code u} is
 * {@code userBias[u] + itemBias[j] + w_u * sum_{i in rated(u), i != j} P[i] . Q[j]},
 * where {@code w_u} is a power-law normalisation controlled by {@code alpha}.
 *
 * <p>Configuration keys read in {@link #setup()}:
 * <ul>
 *   <li>{@code rec.recommender.rho} (required) - multiplied by the number of training entries
 *       to obtain the requested number of negative samples per iteration</li>
 *   <li>{@code rec.recommender.alpha} (default 0.5) - normalisation exponent</li>
 *   <li>{@code rec.recommender.beta} (default 0.6) - regularisation weight for P and Q</li>
 *   <li>{@code rec.recommender.itemBiasReg} / {@code rec.recommender.userBiasReg}
 *       (default 0.1) - regularisation weights for the bias terms</li>
 *   <li>{@code rec.iteration.learnrate} (default 1.0E-4) - SGD learning rate</li>
 *   <li>{@code guava.cache.spec} (default {@code maximumSize=200,expireAfterAccess=2m}) -
 *       spec of the per-user rated-items cache used at prediction time</li>
 * </ul>
 */
@ModelData(value={"isRanking", "fismrmse", "P", "Q", "itemBiases", "userBiases"})
public class FISMrmseRecommender
extends MatrixFactorizationRecommender {
    /** Number of entries in the training matrix. */
    private int nnz;
    /** Ratio of negative samples to training entries. */
    private float rho;
    /** Exponent of the neighbourhood-size normalisation. */
    private float alpha;
    /** Regularisation weight applied to the rows of P and Q. */
    private float beta;
    /** Regularisation weight applied to item biases. */
    private float itemBiasReg;
    /** Regularisation weight applied to user biases. */
    private float userBiasReg;
    /** SGD learning rate. */
    private double lRate;
    /** Per-item bias terms. */
    private VectorBasedDenseVector itemBiases;
    /** Per-user bias terms. */
    private VectorBasedDenseVector userBiases;
    /** Item factors aggregated over the items a user has rated (numItems x numFactors). */
    private DenseMatrix P;
    /** Item factors of the item being scored (numItems x numFactors). */
    private DenseMatrix Q;
    /** Cache mapping a user index to the item indices of that user's training row. */
    protected LoadingCache<Integer, List<Integer>> userItemsCache;
    /** Guava cache specification used to build {@link #userItemsCache}. */
    protected static String cacheSpec;

    /**
     * Allocates and randomly initialises the model parameters and reads the hyper-parameters
     * from the configuration.
     *
     * @throws LibrecException if the parent setup fails or the required
     *                         {@code rec.recommender.rho} setting is missing
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.P = new DenseMatrix(this.numItems, this.numFactors);
        this.Q = new DenseMatrix(this.numItems, this.numFactors);
        this.P.init(0.0, 0.01);
        this.Q.init(0.0, 0.01);
        this.userBiases = new VectorBasedDenseVector(this.numUsers);
        this.itemBiases = new VectorBasedDenseVector(this.numItems);
        this.userBiases.init(0.0, 0.01);
        this.itemBiases.init(0.0, 0.01);
        this.nnz = this.trainMatrix.size();
        // rho has no default; fail with a descriptive error instead of an unboxing NPE.
        Float configuredRho = this.conf.getFloat("rec.recommender.rho");
        if (configuredRho == null) {
            throw new LibrecException("Missing required configuration: rec.recommender.rho");
        }
        this.rho = configuredRho.floatValue();
        this.alpha = this.conf.getFloat("rec.recommender.alpha", Float.valueOf(0.5f)).floatValue();
        this.beta = this.conf.getFloat("rec.recommender.beta", Float.valueOf(0.6f)).floatValue();
        this.itemBiasReg = this.conf.getFloat("rec.recommender.itemBiasReg", Float.valueOf(0.1f)).floatValue();
        this.userBiasReg = this.conf.getFloat("rec.recommender.userBiasReg", Float.valueOf(0.1f)).floatValue();
        this.lRate = this.conf.getDouble("rec.iteration.learnrate", 1.0E-4);
        cacheSpec = this.conf.get("guava.cache.spec", "maximumSize=200,expireAfterAccess=2m");
        this.userItemsCache = this.trainMatrix.rowColumnsCache(cacheSpec);
    }

    /**
     * Runs up to {@code numIterations} SGD passes. Each pass takes the training data table,
     * optionally adds sampled zero-valued entries, and performs one update per table cell.
     *
     * @throws LibrecException if a step of the training loop fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        int sampleSize = (int)(this.rho * (float)this.nnz);
        // Compute the number of unobserved cells in long: numUsers * numItems can exceed
        // the int range on large data sets. Clamp because the sampler takes an int bound.
        long unobserved = (long)this.numUsers * (long)this.numItems - (long)this.nnz;
        int sampleRange = (int)Math.min((long)Integer.MAX_VALUE, unobserved);
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            Table<Integer, Integer, Double> R = this.trainMatrix.getDataTable();
            this.addSampledNegatives(R, sampleSize, sampleRange);
            for (Table.Cell<Integer, Integer, Double> cell : R.cellSet()) {
                this.updateOnCell(cell.getRowKey(), cell.getColumnKey(), cell.getValue());
            }
            this.loss *= 0.5;
            if (this.isConverged(iter) && this.earlyStop) break;
            this.updateLRate(iter);
        }
    }

    /**
     * Marks sampled zero-valued training entries as explicit 0.0 cells in {@code R}.
     *
     * <p>Only entries yielded by iterating {@code trainMatrix} are considered; the k-th
     * zero-valued entry is selected when k equals the next drawn index.
     * NOTE(review): this presumably relies on the drawn indices being in ascending order -
     * confirm against {@code Randoms.randInts}.
     *
     * @param R           table that receives the sampled entries
     * @param sampleSize  number of indices requested from the sampler
     * @param sampleRange exclusive upper bound passed to the sampler
     */
    private void addSampledNegatives(Table<Integer, Integer, Double> R, int sampleSize, int sampleRange) throws LibrecException {
        List<Integer> indices = null;
        try {
            indices = Randoms.randInts(sampleSize, 0, sampleRange);
        }
        catch (Exception e) {
            // Sampling is best-effort: report the failure and train on the existing cells only.
            e.printStackTrace();
        }
        if (indices == null || indices.isEmpty()) {
            // Nothing to select; avoids dereferencing a missing or empty index list below.
            return;
        }
        int index = 0;
        int count = 0;
        for (MatrixEntry matrixEntry : this.trainMatrix) {
            if (matrixEntry.get() != 0.0) continue;
            if (count++ != indices.get(index)) continue;
            R.put(matrixEntry.row(), matrixEntry.column(), 0.0);
            if (++index >= indices.size()) break;
        }
    }

    /**
     * Performs one SGD step for the cell {@code (u, i)} with target value {@code rui} and
     * adds the squared error and regularisation terms to {@code loss}.
     *
     * @param u   user index
     * @param i   item index
     * @param rui target value of the cell
     */
    private void updateOnCell(int u, int i, double rui) throws LibrecException {
        SequentialSparseVector Ru = this.trainMatrix.row(u);
        int n_u = Ru.size() - 1;
        if (n_u <= 0) {
            // Avoid a zero or negative base in the normalisation below.
            n_u = 1;
        }
        double norm = Math.pow(n_u, -this.alpha);

        // Aggregate the P rows of every other item in the user's row.
        DenseVector X = new VectorBasedDenseVector(this.numFactors);
        for (int j : Ru.getIndices()) {
            if (i == j) continue;
            X = X.plus(this.P.row(j));
        }
        X = X.times(norm);

        double bi = this.itemBiases.get(i);
        double bu = this.userBiases.get(u);
        double pui = bu + bi + this.Q.row(i).dot(X);
        double eui = rui - pui;
        this.loss += eui * eui;

        this.itemBiases.plus(i, this.lRate * (eui - (double)this.itemBiasReg * bi));
        this.loss += (double)this.itemBiasReg * bi * bi;
        this.userBiases.plus(u, this.lRate * (eui - (double)this.userBiasReg * bu));
        // The user-bias penalty uses the user-bias weight, matching the gradient step above.
        this.loss += (double)this.userBiasReg * bu * bu;

        DenseVector deltaq = X.times(eui).minus(this.Q.row(i).times(this.beta));
        this.loss += (double)this.beta * this.Q.row(i).dot(this.Q.row(i));
        this.Q.set(i, this.Q.row(i).plus(deltaq.times(this.lRate)));

        // The P updates deliberately read the already-updated Q row, as in the original order.
        for (int j : Ru.getIndices()) {
            if (i == j) continue;
            DenseVector deltap = this.Q.row(i).times(eui * norm).minus(this.P.row(j).times(this.beta));
            this.loss += (double)this.beta * this.P.row(j).dot(this.P.row(j));
            this.P.set(j, this.P.row(j).plus(deltap.times(this.lRate)));
        }
    }

    /**
     * Scores item {@code j} for user {@code u}.
     *
     * @param u user index
     * @param j item index
     * @return the two bias terms plus the normalised sum of {@code P[i] . Q[j]} over the
     *         user's cached items {@code i != j}; the sum contributes nothing when fewer
     *         than two such items exist
     * @throws LibrecException if the user's items cannot be loaded from the cache
     */
    @Override
    protected double predict(int u, int j) throws LibrecException {
        double pred = this.userBiases.get(u) + this.itemBiases.get(j);
        List<Integer> ratedItems;
        try {
            ratedItems = this.userItemsCache.get(u);
        }
        catch (ExecutionException e) {
            // Surface the failure with its cause instead of continuing with a null list.
            throw new LibrecException("Failed to load the rated items of user " + u, e);
        }
        double sum2 = 0.0;
        int count = 0;
        for (int i : ratedItems) {
            if (i == j) continue;
            sum2 += this.P.row(i).dot(this.Q.row(j));
            ++count;
        }
        double wu = count - 1 > 0 ? Math.pow(count - 1, -this.alpha) : 0.0;
        return pred + wu * sum2;
    }
}

