/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.content;

import com.google.common.collect.BiMap;
import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import net.librec.common.LibrecException;
import net.librec.math.structure.DataFrame;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SequentialAccessSparseMatrix;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.SparseTensor;
import net.librec.math.structure.TensorEntry;
import net.librec.math.structure.Vector;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.TensorRecommender;
import net.librec.recommender.content.ConvMFDocumentDataSetIterator;
import net.librec.recommender.content.ConvMFDocumentProvider;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 * ConvMF-style recommender: alternating least-squares matrix factorization in which each
 * item's latent vector is additionally pulled toward the output of a text CNN applied to
 * the item's document (the concatenation of the reviews attached to that item).
 *
 * <p>Per iteration of {@link #trainModel()}:
 * <ol>
 *   <li>the CNN is fitted for {@code CNN_Module.nEpochs} epochs;</li>
 *   <li>each user row is solved as
 *       {@code u = (V_u^T V_u + lambda_u * n_u * I)^-1 V_u^T r_u};</li>
 *   <li>each item row is solved as
 *       {@code v = (U_i^T U_i + lambda_v * n_i * I)^-1 (U_i^T r_i + lambda_v * cnn(doc_i))};</li>
 *   <li>the summed squared training error is logged as the loss.</li>
 * </ol>
 */
public class ConvMFRecommender
extends TensorRecommender {
    /** User-by-item rating matrix derived from {@code trainTensor} in {@link #setup()}. */
    protected SequentialAccessSparseMatrix trainMatrix;
    /** Latent user factors, {@code numUsers x numFactors}. */
    protected DenseMatrix userFactors;
    /** Latent item factors, {@code numItems x numFactors}. */
    protected DenseMatrix itemFactors;
    /** User regularization weight, read from {@code rec.user.regularization} (default 0.1). */
    protected float lambda_u;
    /** Item regularization / CNN-prior weight, read from {@code rec.item.regularization} (default 0.1). */
    protected double lambda_v;
    /** Pre-trained word2vec model path: {@code dfs.data.dir + "/" + rec.word2vec.path}. */
    protected String pretrain_w2v_path;
    /** Word-vector dimension ({@code rec.word2vec.dimension}); also the CNN kernel width. */
    protected int w2v_dim;
    /** Maximum document length handed to the document iterator ({@code rec.document.length}). */
    protected int max_len;
    /** Feature maps per convolution layer ({@code rec.featuremap.num}). */
    protected int featureMapNum;
    /** Review index to review text (inverse of the "review" inner mapping). */
    public BiMap<Integer, String> reviewMappingData;
    /** Item index to that item's concatenated review text. */
    protected Map<Integer, StringBuilder> itemIdx2document;
    /** CNN mapping an item document to a {@code numFactors}-dimensional vector. */
    protected CNN_Module cnn_module;

    /**
     * Reads the configuration, allocates the factor matrices, builds one text document per
     * item from the review tensors, and constructs the CNN (which loads the word vectors).
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.reviewMappingData = DataFrame.getInnerMapping("review").inverse();
        this.lambda_u = this.conf.getFloat("rec.user.regularization", Float.valueOf(0.1f)).floatValue();
        this.lambda_v = this.conf.getFloat("rec.item.regularization", Float.valueOf(0.1f)).floatValue();
        this.trainTensor = (SparseTensor)this.getDataModel().getTrainDataSet();
        this.userFactors = new DenseMatrix(this.numUsers, this.numFactors);
        this.itemFactors = new DenseMatrix(this.numItems, this.numFactors);
        this.userFactors.init(1.0);
        this.itemFactors.init(1.0);
        this.pretrain_w2v_path = this.conf.get("dfs.data.dir") + "/" + this.conf.get("rec.word2vec.path");
        this.w2v_dim = this.conf.getInt("rec.word2vec.dimension");
        this.max_len = this.conf.getInt("rec.document.length");
        this.featureMapNum = this.conf.getInt("rec.featuremap.num");
        this.trainMatrix = this.trainTensor.rateMatrix();
        this.itemIdx2document = new HashMap<Integer, StringBuilder>();
        for (int i = 0; i < this.numItems; ++i) {
            this.itemIdx2document.put(i, new StringBuilder());
        }
        appendReviewsToDocuments(this.trainTensor);
        // NOTE(review): reviews attached to *test* entries are also folded into the item
        // documents that the CNN is trained on, so test-period text influences training.
        // Confirm this is intended rather than leakage.
        appendReviewsToDocuments(this.testTensor);
        this.cnn_module = new CNN_Module();
    }

    /**
     * Appends the text of every review referenced by {@code tensor} to its item's document.
     * Key 1 of each entry is the item index and key 2 the review index. ':' and '#' are
     * replaced by spaces, and each review is terminated with '.'.
     *
     * @param tensor rating/review tensor whose entries carry (.., item, review) keys
     */
    private void appendReviewsToDocuments(SparseTensor tensor) {
        for (TensorEntry te : tensor) {
            int[] keys = te.keys();
            int itemIndex = keys[1];
            int reviewIndex = keys[2];
            String reviewContent = this.reviewMappingData.get(reviewIndex);
            String cleaned = reviewContent.replace(':', ' ').replace('#', ' ');
            this.itemIdx2document.get(itemIndex).append(cleaned).append('.');
        }
    }

    /**
     * Runs {@code numIterations} rounds of CNN fitting followed by closed-form user and item
     * updates, logging the summed squared training error after each round.
     *
     * @throws LibrecException declared by the parent contract
     */
    @Override
    protected void trainModel() throws LibrecException {
        // Build a numFactors x numFactors identity; init(0.0) is relied on to zero every entry.
        DenseMatrix identity = new DenseMatrix(this.numFactors, this.numFactors);
        identity.init(0.0);
        for (int k = 0; k < this.numFactors; ++k) {
            identity.set(k, k, 1.0);
        }
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.cnn_module.trainCNN();
            updateUserFactors(identity);
            updateItemFactors(identity);
            computeLoss();
            this.LOG.info("iter: " + iter + ", loss: " + this.loss);
        }
    }

    /**
     * Solves each user's ridge system {@code (V^T V + lambda_u * n * I) u = V^T r} against
     * the current item factors. Users without training ratings are skipped: with
     * {@code n == 0} the system matrix is all zeros and cannot be inverted.
     *
     * @param identity {@code numFactors x numFactors} identity matrix
     */
    private void updateUserFactors(DenseMatrix identity) {
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            SequentialSparseVector userVec = this.trainMatrix.row(userIdx);
            int numRated = userVec.getNumEntries();
            if (numRated == 0) {
                // Singular system; keep this user's current factors.
                continue;
            }
            DenseMatrix itemRows = new DenseMatrix(numRated, this.numFactors);
            int row = 0;
            for (int itemIdx : userVec.getIndices()) {
                itemRows.set(row++, this.itemFactors.row(itemIdx));
            }
            DenseMatrix system = (DenseMatrix) itemRows.transpose().times(itemRows)
                    .plus(identity.times(this.lambda_u).times(numRated));
            VectorBasedDenseVector ratings = toDenseRatings(userVec, numRated);
            this.userFactors.set(userIdx, system.inverse().times(itemRows.transpose().times(ratings)));
        }
    }

    /**
     * Solves each item's system
     * {@code (U^T U + lambda_v * n * I) v = U^T r + lambda_v * cnn(doc)} against the current
     * user factors. Items without training ratings are skipped.
     *
     * @param identity {@code numFactors x numFactors} identity matrix
     */
    private void updateItemFactors(DenseMatrix identity) {
        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            SequentialSparseVector itemVec = this.trainMatrix.column(itemIdx);
            int numRaters = itemVec.getNumEntries();
            if (numRaters == 0) {
                continue;
            }
            DenseMatrix userRows = new DenseMatrix(numRaters, this.numFactors);
            int row = 0;
            for (int userIdx : itemVec.getIndices()) {
                userRows.set(row++, this.userFactors.row(userIdx));
            }
            DenseMatrix system = (DenseMatrix) userRows.transpose().times(userRows)
                    .plus(identity.times(this.lambda_v).times(numRaters));
            VectorBasedDenseVector ratings = toDenseRatings(itemVec, numRaters);
            VectorBasedDenseVector cnnPrior = this.cnn_module.getOutput(this.itemIdx2document.get(itemIdx).toString());
            this.itemFactors.set(itemIdx, system.inverse().times(
                    userRows.transpose().times(ratings).plus(cnnPrior.times(this.lambda_v))));
        }
    }

    /**
     * Copies the stored values of a sparse rating vector, in iteration order, into a dense
     * vector of length {@code size}.
     *
     * @param sparse sparse rating vector
     * @param size   number of stored entries in {@code sparse}
     * @return dense vector of the stored ratings
     */
    private static VectorBasedDenseVector toDenseRatings(SequentialSparseVector sparse, int size) {
        VectorBasedDenseVector dense = new VectorBasedDenseVector(size);
        int pos = 0;
        for (Vector.VectorEntry ve : sparse) {
            dense.set(pos++, ve.get());
        }
        return dense;
    }

    /**
     * Sets {@code loss} to the summed squared prediction error over the training matrix.
     * Regularization terms are not included.
     */
    private void computeLoss() {
        this.loss = 0.0;
        for (MatrixEntry me : this.trainMatrix) {
            double error = me.get() - this.userFactors.row(me.row()).dot(this.itemFactors.row(me.column()));
            this.loss += error * error;
        }
    }

    /**
     * Predicts from a tensor index; only the user (index 0) and item (index 1) keys are used.
     */
    @Override
    protected double predict(int[] indices) {
        return this.predict(indices[0], indices[1]);
    }

    /**
     * Predicts a rating as the dot product of the user and item factor rows.
     */
    @Override
    protected double predict(int userIdx, int itemIdx) {
        return this.userFactors.row(userIdx).dot(this.itemFactors.row(itemIdx));
    }

    /**
     * Text CNN with three parallel convolutions (kernel heights 3, 4 and 5, width = word-vector
     * dimension), merged and globally max-pooled, then fed to a ReLU output layer of
     * {@code numFactors} units trained with MSE against item factor rows.
     */
    private class CNN_Module {
        ComputationGraph net;
        ConvMFDocumentDataSetIterator trainIter;
        int nEpochs = 5;
        int batchSize = 32;
        int vectorSize;
        int truncateReviewsToLength;
        int cnnLayerFeatureMaps;

        /**
         * Builds and initializes the network, loads the word vectors from
         * {@code pretrain_w2v_path}, and creates the training iterator.
         */
        CNN_Module() {
            this.vectorSize = ConvMFRecommender.this.w2v_dim;
            this.truncateReviewsToLength = ConvMFRecommender.this.max_len;
            this.cnnLayerFeatureMaps = ConvMFRecommender.this.featureMapNum;
            PoolingType globalPoolingType = PoolingType.MAX;
            ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.RELU)
                    .activation(Activation.LEAKYRELU)
                    .updater(Updater.ADADELTA)
                    .convolutionMode(ConvolutionMode.Same)
                    .regularization(true)
                    .dropOut(0.2)
                    .learningRate((double)ConvMFRecommender.this.learnRate)
                    .graphBuilder()
                    .addInputs(new String[]{"input"})
                    .addLayer("cnn3", (Layer)new ConvolutionLayer.Builder()
                            .kernelSize(new int[]{3, this.vectorSize})
                            .stride(new int[]{1, this.vectorSize})
                            .nIn(1).nOut(this.cnnLayerFeatureMaps).build(), new String[]{"input"})
                    .addLayer("cnn4", (Layer)new ConvolutionLayer.Builder()
                            .kernelSize(new int[]{4, this.vectorSize})
                            .stride(new int[]{1, this.vectorSize})
                            .nIn(1).nOut(this.cnnLayerFeatureMaps).build(), new String[]{"input"})
                    .addLayer("cnn5", (Layer)new ConvolutionLayer.Builder()
                            .kernelSize(new int[]{5, this.vectorSize})
                            .stride(new int[]{1, this.vectorSize})
                            .nIn(1).nOut(this.cnnLayerFeatureMaps).build(), new String[]{"input"})
                    .addVertex("merge", (GraphVertex)new MergeVertex(), new String[]{"cnn3", "cnn4", "cnn5"})
                    .addLayer("globalPool", (Layer)new GlobalPoolingLayer.Builder()
                            .poolingType(globalPoolingType).build(), new String[]{"merge"})
                    .addLayer("out", (Layer)((OutputLayer.Builder)((OutputLayer.Builder)((OutputLayer.Builder)((OutputLayer.Builder)new OutputLayer.Builder()
                            .lossFunction(LossFunctions.LossFunction.MSE))
                            .activation(Activation.RELU))
                            .nIn(3 * this.cnnLayerFeatureMaps))
                            .nOut(ConvMFRecommender.this.numFactors))
                            .build(), new String[]{"globalPool"})
                    .setOutputs(new String[]{"out"})
                    .build();
            this.net = new ComputationGraph(config);
            this.net.init();
            ConvMFRecommender.this.LOG.info("Loading word vectors and creating DataSetIterators");
            WordVectors wordVectors = WordVectorSerializer.loadStaticModel(new File(ConvMFRecommender.this.pretrain_w2v_path));
            this.trainIter = this.getDataSetIterator(wordVectors, this.batchSize, this.truncateReviewsToLength);
        }

        /**
         * Fits the network for {@code nEpochs} passes over {@code trainIter}.
         */
        void trainCNN() {
            // NOTE(review): trainIter and its labels are built only once, in the constructor.
            // Unless the label arrays are live views of itemFactors rows, every call keeps
            // regressing onto the initial item factors. Verify the intended ConvMF behavior.
            for (int epoch = 0; epoch < this.nEpochs; ++epoch) {
                this.net.fit((DataSetIterator)this.trainIter);
            }
        }

        /**
         * Runs the network on one document and copies the first {@code numFactors} output
         * values into a dense vector.
         *
         * @param inputDocument item document text
         * @return CNN projection of the document
         */
        VectorBasedDenseVector getOutput(String inputDocument) {
            INDArray predictions = this.net.outputSingle(new INDArray[]{this.trainIter.loadSingleSentence(inputDocument)});
            VectorBasedDenseVector outputVec = new VectorBasedDenseVector(ConvMFRecommender.this.numFactors);
            for (int i = 0; i < ConvMFRecommender.this.numFactors; ++i) {
                outputVec.set(i, predictions.getDouble(i));
            }
            return outputVec;
        }

        /**
         * Builds the training iterator. It pairs each item's document with the values of that
         * item's factor row at the time of the call.
         */
        private ConvMFDocumentDataSetIterator getDataSetIterator(WordVectors wordVectors, int minibatchSize, int maxSentenceLength) {
            int numItems = ConvMFRecommender.this.numItems;
            ArrayList<String> documents = new ArrayList<String>(numItems);
            ArrayList<double[]> labelsForDocuments = new ArrayList<double[]>(numItems);
            for (int i = 0; i < numItems; ++i) {
                documents.add(ConvMFRecommender.this.itemIdx2document.get(i).toString());
                labelsForDocuments.add(ConvMFRecommender.this.itemFactors.row(i).getValues());
            }
            ConvMFDocumentProvider documentProvider = new ConvMFDocumentProvider(documents, labelsForDocuments);
            return new ConvMFDocumentDataSetIterator.Builder()
                    .documentProvider(documentProvider)
                    .wordVectors(wordVectors)
                    .minibatchSize(minibatchSize)
                    .maxSentenceLength(maxSentenceLength)
                    .useNormalizedWordVectors(false)
                    .build();
        }
    }
}

