/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.nn.rating;

import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.structure.MatrixEntry;
import net.librec.recommender.MatrixRecommender;
import net.librec.recommender.nn.rating.AutoRecLossFunction;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;

@ModelData(value={"isRating", "autorec", "autoRecModel", "trainSet"})
public class AutoRecRecommender
extends MatrixRecommender {
    private int inputDim;
    private int hiddenDim;
    private double learningRate;
    private double momentum;
    private double lambdaReg;
    private int numIterations;
    private String hiddenActivation;
    private String outputActivation;
    private MultiLayerNetwork autoRecModel;
    private INDArray trainSet;
    private INDArray trainSetMask;

    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.inputDim = this.numUsers;
        this.hiddenDim = this.conf.getInt("rec.hidden.dimension");
        this.learningRate = this.conf.getDouble("rec.iterator.learnrate");
        this.lambdaReg = this.conf.getDouble("rec.weight.regularization");
        this.numIterations = this.conf.getInt("rec.iterator.maximum");
        this.hiddenActivation = this.conf.get("rec.hidden.activation");
        this.outputActivation = this.conf.get("rec.output.activation");
        int[] matrixShape = new int[]{this.numItems, this.numUsers};
        this.trainSet = Nd4j.zeros((int[])matrixShape);
        this.trainSetMask = Nd4j.zeros((int[])matrixShape);
        for (MatrixEntry me : this.trainMatrix) {
            this.trainSet.put(me.column(), me.row(), (Number)me.get());
            this.trainSetMask.put(me.column(), me.row(), (Number)1);
        }
    }

    @Override
    protected void trainModel() throws LibrecException {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(1).updater(Updater.NESTEROVS).learningRate(this.learningRate).weightInit(WeightInit.XAVIER_UNIFORM).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).regularization(true).l2(this.lambdaReg).list().layer(0, (Layer)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)new DenseLayer.Builder().nIn(this.inputDim)).nOut(this.hiddenDim)).activation(Activation.fromString((String)this.hiddenActivation))).biasInit(0.1)).build()).layer(1, (Layer)((OutputLayer.Builder)((OutputLayer.Builder)((OutputLayer.Builder)((OutputLayer.Builder)new OutputLayer.Builder((ILossFunction)new AutoRecLossFunction()).nIn(this.hiddenDim)).nOut(this.inputDim)).activation(Activation.fromString((String)this.outputActivation))).biasInit(0.1)).build()).pretrain(false).backprop(true).build();
        this.autoRecModel = new MultiLayerNetwork(conf);
        this.autoRecModel.init();
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            AutoRecLossFunction.trainMask = this.trainSetMask;
            this.autoRecModel.fit(this.trainSet, this.trainSet);
            this.loss = this.autoRecModel.score();
            if (this.isConverged(iter) && this.earlyStop) break;
            this.lastLoss = this.loss;
        }
    }

    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        INDArray predictedRatingVector = this.autoRecModel.output(this.trainSet.getRow(itemIdx));
        return predictedRatingVector.getDouble(userIdx);
    }
}

