/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.rating;

import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.TensorEntry;
import net.librec.math.structure.Vector;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.math.structure.VectorBasedSequentialSparseVector;
import net.librec.recommender.FactorizationMachineRecommender;

public class FMFTRLRecommender
extends FactorizationMachineRecommender {
    private double lambda1;
    private double lambda2;
    private double alpha;
    private double beta;

    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.lambda1 = this.conf.getDouble("rec.regularization.lambda1");
        this.lambda2 = this.conf.getDouble("rec.regularization.lambda2");
        this.alpha = this.conf.getDouble("rec.learningRate.alpha");
        this.beta = this.conf.getDouble("rec.learningRate.beta");
    }

    @Override
    protected void trainModel() throws LibrecException {
        if (!this.isRanking) {
            this.buildRatingModel();
        }
    }

    private void buildRatingModel() throws LibrecException {
        double zW0 = 0.0;
        VectorBasedDenseVector zW = new VectorBasedDenseVector(this.p);
        DenseMatrix zV = new DenseMatrix(this.p, this.k);
        zW.init(0.0);
        zV.init(0.0);
        double nW0 = 0.0;
        VectorBasedDenseVector nW = new VectorBasedDenseVector(this.p);
        DenseMatrix nV = new DenseMatrix(this.p, this.k);
        nW.init(0.0);
        nV.init(0.0);
        VectorBasedDenseVector gW = new VectorBasedDenseVector(this.p);
        VectorBasedDenseVector thetaW = new VectorBasedDenseVector(this.p);
        DenseMatrix gV = new DenseMatrix(this.p, this.k);
        DenseMatrix thetaV = new DenseMatrix(this.p, this.k);
        for (int iter = 0; iter < this.numIterations; ++iter) {
            this.loss = 0.0;
            for (TensorEntry me : this.trainTensor) {
                int[] entryKeys = me.keys();
                VectorBasedSequentialSparseVector x = this.tenserKeysToFeatureVector(entryKeys);
                double rate = me.get();
                double pred = this.predict(x);
                double err = pred - rate;
                this.loss += err * err;
                double gradLoss = err;
                double hW0 = 1.0;
                double gW0 = gradLoss * hW0;
                double thetaW0 = 1.0 / this.alpha * (Math.sqrt(nW0 + Math.pow(gW0, 2.0)) - Math.sqrt(nW0));
                this.w0 = Math.abs(zW0) <= this.lambda1 ? 0.0 : -1.0 / ((this.beta + Math.sqrt(nW0 += Math.pow(gW0, 2.0))) / this.alpha + this.lambda2) * ((zW0 += gW0 - thetaW0 * this.w0) - (double)this.sgn(zW0) * this.lambda1);
                for (Vector.VectorEntry ve : x) {
                    int l = ve.index();
                    double hWl = ve.get();
                    gW.set(l, gradLoss * hWl);
                    thetaW.set(l, 1.0 / this.alpha * (Math.sqrt(nW.get(l) + Math.pow(gW.get(l), 2.0)) - Math.sqrt(nW.get(l))));
                    zW.plus(l, gW.get(l) - thetaW.get(l) * this.W.get(l));
                    nW.plus(l, Math.pow(gW.get(l), 2.0));
                    if (Math.abs(zW.get(l)) <= this.lambda1) {
                        this.W.set(l, 0.0);
                    } else {
                        double value = -1.0 / ((this.beta + Math.sqrt(nW.get(l))) / this.alpha + this.lambda2) * (zW.get(l) - (double)this.sgn(zW.get(l)) * this.lambda1);
                        this.W.set(l, value);
                    }
                    for (int f = 0; f < this.k; ++f) {
                        double hVlf = 0.0;
                        double xl = ve.get();
                        for (Vector.VectorEntry ve2 : x) {
                            int j = ve2.index();
                            if (j == l) continue;
                            hVlf += xl * this.V.get(j, f) * ve2.get();
                        }
                        double gradVlf = gradLoss * hVlf;
                        gV.set(l, f, gradVlf);
                        thetaV.set(l, f, 1.0 / this.alpha * (Math.sqrt(nV.get(l, f) + Math.pow(gV.get(l, f), 2.0)) - Math.sqrt(nV.get(l, f))));
                        zV.plus(l, f, gV.get(l, f) - thetaV.get(l, f) * this.V.get(l, f));
                        nV.plus(l, f, Math.pow(gV.get(l, f), 2.0));
                        if (Math.abs(zV.get(l, f)) <= this.lambda1) {
                            this.V.set(l, f, 0.0);
                            continue;
                        }
                        double value = -1.0 / ((this.beta + Math.sqrt(nV.get(l, f))) / this.alpha + this.lambda2) * (zV.get(l, f) - (double)this.sgn(zV.get(l, f)) * this.lambda1);
                        this.V.set(l, f, value);
                    }
                }
            }
            this.loss *= 0.5;
            if (this.isConverged(iter) && this.earlyStop) break;
        }
    }

    private int sgn(double value) {
        return value > 0.0 ? 1 : (value == 0.0 ? 0 : -1);
    }

    @Override
    @Deprecated
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        return 0.0;
    }
}

