/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import com.google.common.cache.LoadingCache;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.Vector;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * FISMauc recommender: a factored item-similarity model trained with a pairwise ranking loss.
 *
 * <p>For user {@code u} and rated item {@code i}, the user is represented by
 * {@code t = (n - 1)^-alpha * sum_{j in R_u, j != i} P_j} with {@code n = max(|R_u|, 2)}, and an
 * item {@code k} is scored as {@code b_k + Q_k . t}. For every rated item, unrated items are
 * sampled as negatives and SGD minimizes
 * {@code 0.5 * sum ((r_ui - 0) - (score_i - score_j))^2
 * + 0.5 * (gamma * |b|^2 + beta * (|P|^2 + |Q|^2))}.
 *
 * <p>NOTE(review): {@code @ModelData} lists {@code "userBiases"}, which this class does not
 * declare — confirm whether it is inherited or a stale entry.
 */
@ModelData(value={"isRanking", "fismauc", "P", "Q", "itemBiases", "userBiases"})
public class FISMaucRecommender
extends MatrixFactorizationRecommender {
    /** Negative-sampling ratio: {@code (int) (rho * n)} candidates are drawn per rated item. */
    private float rho;
    /** Exponent of the neighborhood normalization {@code (n - 1)^-alpha}. */
    private float alpha;
    /** L2 regularization weight of the factor matrices P and Q. */
    private float beta;
    /** L2 regularization weight of the item biases. */
    private float gamma;
    /**
     * SGD step size, read once in {@link #setup()}.
     * NOTE(review): {@code updateLRate(iter)} presumably adjusts the superclass learning rate,
     * not this field — confirm whether learning-rate decay is meant to apply here.
     */
    private double lRate;
    /** Per-item bias {@code b_k}, length {@code numItems}. */
    private VectorBasedDenseVector itemBiases;
    /** Item factors summed into the user representation ({@code numItems x numFactors}). */
    private DenseMatrix P;
    /** Item factors of the scored item ({@code numItems x numFactors}). */
    private DenseMatrix Q;
    /** Cache of each user's rated item ids, built from the training matrix in setup. */
    protected LoadingCache<Integer, List<Integer>> userItemsCache;
    /** Guava cache spec used for {@link #userItemsCache}; static, so every setup overwrites it. */
    protected static String cacheSpec;

    /**
     * Allocates and randomly initializes P, Q and the item biases, then reads the hyper-parameters.
     *
     * @throws LibrecException if the superclass setup fails, or if {@code rec.recommender.rho} is
     *     missing or not strictly positive (training divides by it)
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.P = new DenseMatrix(this.numItems, this.numFactors);
        this.Q = new DenseMatrix(this.numItems, this.numFactors);
        this.P.init(0.0, 0.01);
        this.Q.init(0.0, 0.01);
        this.itemBiases = new VectorBasedDenseVector(this.numItems);
        this.itemBiases.init(0.0, 0.01);
        // rho has no default. A null value would fail on unboxing, and rho <= 0 would yield an
        // empty or invalid sample plus a 1 / rho factor that turns the P updates into NaN.
        Float configuredRho = this.conf.getFloat("rec.recommender.rho");
        if (configuredRho == null || !(configuredRho > 0.0f)) {
            throw new LibrecException(
                    "rec.recommender.rho must be set to a positive value, got " + configuredRho);
        }
        this.rho = configuredRho;
        this.alpha = this.conf.getFloat("rec.recommender.alpha", Float.valueOf(0.5f)).floatValue();
        this.beta = this.conf.getFloat("rec.recommender.beta", Float.valueOf(0.6f)).floatValue();
        this.gamma = this.conf.getFloat("rec.recommender.gamma", Float.valueOf(0.1f)).floatValue();
        this.lRate = this.conf.getDouble("rec.iteration.learnrate", 1.0E-4);
        cacheSpec = this.conf.get("guava.cache.spec", "maximumSize=200,expireAfterAccess=2m");
        this.userItemsCache = this.trainMatrix.rowColumnsCache(cacheSpec);
    }

    /**
     * Runs up to {@code numIterations} SGD epochs over every (user, rated item) pair and stores
     * the halved, regularized squared-error loss of each epoch in {@code loss}.
     *
     * @throws LibrecException as declared by the superclass hooks invoked here
     * @throws IllegalStateException if negative items cannot be sampled
     */
    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            for (int u = 0; u < this.numUsers; ++u) {
                this.trainUser(u);
            }
            this.loss += this.regularizationPenalty();
            this.loss *= 0.5;
            if (this.isConverged(iter) && this.earlyStop) {
                break;
            }
            this.updateLRate(iter);
        }
    }

    /** Performs one SGD pass over the rated items of user {@code u}, adding squared errors to {@code loss}. */
    private void trainUser(int u) throws LibrecException {
        SequentialSparseVector userRow = this.trainMatrix.row(u);
        Set<Integer> ratedItems = Arrays.stream(userRow.getIndices()).boxed().collect(Collectors.toSet());
        // Profiles with fewer than two items are treated as having two, so (n - 1)^-alpha is finite.
        int n = Math.max(userRow.getNumEntries(), 2);
        double neighborhoodNorm = Math.pow(n - 1, -this.alpha);
        int sampleSize = (int) (this.rho * (float) n);
        double pGradientScale = (1.0 / this.rho) * neighborhoodNorm;

        for (Vector.VectorEntry entry : userRow) {
            int i = entry.index();
            double rui = entry.get();
            DenseVector t = this.sumItemFactorsExcept(ratedItems, i).times(neighborhoodNorm);
            // x accumulates e * (q_i - q_j) over the negatives; it drives the P updates below.
            DenseVector x = new VectorBasedDenseVector(this.numFactors);
            x.init(0.0);

            for (int j : this.sampleNegativeItems(sampleSize, ratedItems)) {
                double bi = this.itemBiases.get(i);
                double bj = this.itemBiases.get(j);
                double pui = bi + this.Q.row(i).dot(t);
                double puj = bj + this.Q.row(j).dot(t);
                double ruj = 0.0; // sampled items are unrated, i.e. implicit zeros
                double e = rui - ruj - (pui - puj);
                this.loss += e * e;

                // Gradient descent on 0.5 * e^2 + regularization:
                //   d/db_i = -e + gamma * b_i      d/db_j =  e + gamma * b_j
                //   d/dq_i = -e * t + beta * q_i   d/dq_j =  e * t + beta * q_j
                this.itemBiases.plus(i, this.lRate * (e - (double) this.gamma * bi));
                this.itemBiases.plus(j, -this.lRate * (e + (double) this.gamma * bj));

                DenseVector stepQi = t.times(e).minus(this.Q.row(i).times(this.beta));
                DenseVector qi = this.Q.row(i).plus(stepQi.times(this.lRate));
                this.Q.set(i, qi);

                DenseVector gradQj = t.times(e).plus(this.Q.row(j).times(this.beta));
                DenseVector qj = this.Q.row(j).minus(gradQj.times(this.lRate));
                this.Q.set(j, qj);

                x = x.plus(qi.minus(qj).times(e));
            }

            for (int j : ratedItems) {
                if (j == i) {
                    continue;
                }
                DenseVector pj = this.P.row(j);
                DenseVector stepPj = x.times(pGradientScale).minus(pj.times(this.beta));
                this.P.set(j, pj.plus(stepPj.times(this.lRate)));
            }
        }
    }

    /** Returns a fresh vector holding the sum of the P rows of {@code items}, skipping {@code excluded}. */
    private DenseVector sumItemFactorsExcept(Set<Integer> items, int excluded) {
        DenseVector sum = new VectorBasedDenseVector(this.numFactors);
        sum.init(0.0);
        for (int j : items) {
            if (j != excluded) {
                sum = sum.plus(this.P.row(j));
            }
        }
        return sum;
    }

    /**
     * Draws candidates via {@code Randoms.randInts(sampleSize, 0, numItems)} and drops the ones in
     * {@code ratedItems}, so the result may hold fewer than {@code sampleSize} items.
     *
     * @throws IllegalStateException if the sampler rejects the request (cause preserved)
     */
    private List<Integer> sampleNegativeItems(int sampleSize, Set<Integer> ratedItems) {
        List<Integer> candidates;
        try {
            candidates = Randoms.randInts(sampleSize, 0, this.numItems);
        } catch (Exception e) {
            throw new IllegalStateException(
                    "Cannot sample " + sampleSize + " negative items out of " + this.numItems + " items", e);
        }
        return candidates.stream()
                .filter(item -> !ratedItems.contains(item))
                .collect(Collectors.toList());
    }

    /** Returns {@code sum_k (gamma * b_k^2 + beta * |Q_k|^2 + beta * |P_k|^2)}, not yet halved. */
    private double regularizationPenalty() {
        double penalty = 0.0;
        for (int k = 0; k < this.numItems; ++k) {
            double bk = this.itemBiases.get(k);
            DenseVector qk = this.Q.row(k);
            DenseVector pk = this.P.row(k);
            penalty += (double) this.gamma * bk * bk;
            penalty += (double) this.beta * qk.dot(qk);
            penalty += (double) this.beta * pk.dot(pk);
        }
        return penalty;
    }

    /**
     * Scores item {@code j} for user {@code u} as
     * {@code b_j + (c - 1)^-alpha * sum_{i in R_u, i != j} P_i . Q_j}, where {@code c} is the number
     * of rated items other than {@code j}; the neighborhood term is dropped when {@code c <= 1}.
     *
     * @param u user index
     * @param j item index
     * @return predicted ranking score
     * @throws LibrecException as declared by the superclass contract
     * @throws IllegalStateException if the user's rated items cannot be loaded from the cache
     */
    @Override
    protected double predict(int u, int j) throws LibrecException {
        List<Integer> ratedItems;
        try {
            ratedItems = this.userItemsCache.get(u);
        } catch (ExecutionException e) {
            throw new IllegalStateException("Failed to load the rated items of user " + u, e);
        }
        DenseVector qj = this.Q.row(j);
        double sum = 0.0;
        int count = 0;
        for (int i : ratedItems) {
            if (i != j) {
                sum += this.P.row(i).dot(qj);
                ++count;
            }
        }
        // NOTE(review): training pads tiny profiles to n = 2, but here a user with one other
        // item (count == 1) gets no neighborhood term at all — confirm this is intended.
        double wu = count - 1 > 0 ? Math.pow(count - 1, -this.alpha) : 0.0;
        return this.itemBiases.get(j) + wu * sum;
    }
}

