/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import com.google.common.collect.BiMap;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import net.librec.common.LibrecException;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Item-item recommender whose scores come from a non-negative factorization trained with
 * multiplicative update rules.
 *
 * <p>Two non-negative matrices of shape {@code [numFactors][numItems]} are learned. For a user
 * with training items {@code I}, the user's factor vector is
 * {@code userFactor[f] = sum_{j in I} h_analyze[f][j]} and the score of item {@code i} is
 * {@code sum_f w_reconstruct[f][i] * userFactor[f]}. While training with
 * {@code doNotEstimateYourself} enabled, the analysis weight of the item being estimated is
 * removed from the user's factor vector.
 *
 * <p>Each training iteration splits the users into chunks of {@code parallelizeSplitUserSize},
 * collects statistics for every chunk in parallel, sums them, and then updates both matrices.
 * With {@code adaptiveUpdateRules} the update exponent grows in steps of 0.1 (while below 1.45)
 * as long as the divergence does not increase, and falls back to 0.5 when it does.
 */
public class NMFItemItemRecommender
extends MatrixFactorizationRecommender {
    /** Exponent used on the first iteration and whenever the divergence increases. */
    private static final double INITIAL_EXPONENT = 0.5;
    /** Amount the exponent grows after an iteration that did not increase the divergence. */
    private static final double EXPONENT_STEP = 0.1;
    /** The exponent is only increased while it is below this value. */
    private static final double EXPONENT_GROWTH_LIMIT = 1.45;
    /** Fixed seed so that model initialization is reproducible. */
    private static final long INIT_SEED = 123456789L;

    /** Reconstruction matrix {@code [factor][item]}; weighs the user's factor vector when scoring an item. */
    private double[][] w_reconstruct;
    /** Analysis matrix {@code [factor][item]}; summed over a user's items to build the user's factor vector. */
    private double[][] h_analyze;
    private int numFactors;
    private int numIterations;
    /** Divergence of the previous iteration; only maintained when {@code adaptiveUpdateRules} is on. */
    private double divergenceFromLastStep;
    /** Exponent applied to the multiplicative update ratios. */
    private double exponent = INITIAL_EXPONENT;
    /** Number of users handled by one parallel task. */
    private int parallelizeSplitUserSize = 5000;
    /** When true, an item's own analysis weight is excluded while estimating that item in training. */
    private boolean doNotEstimateYourself = true;
    /** When true, the update exponent is adapted to the divergence trend. */
    private boolean adaptiveUpdateRules = true;

    /**
     * Reads the configuration and initializes both factor matrices.
     *
     * @throws LibrecException if the parent setup fails
     * @throws IllegalArgumentException if {@code rec.nmfitemitem.parallelize_split_user_size} is not
     *         positive (the user chunking in training would divide by zero or never terminate)
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.numFactors = this.conf.getInt("rec.factor.number", 15);
        this.numIterations = this.conf.getInt("rec.iterator.maximum", 100);
        this.doNotEstimateYourself = this.conf.getBoolean("rec.nmfitemitem.do_not_estimate_yourself", true);
        this.adaptiveUpdateRules = this.conf.getBoolean("rec.nmfitemitem.adaptive_update_rules", true);
        this.parallelizeSplitUserSize = this.conf.getInt("rec.nmfitemitem.parallelize_split_user_size", 5000);
        if (this.parallelizeSplitUserSize < 1) {
            throw new IllegalArgumentException("rec.nmfitemitem.parallelize_split_user_size must be positive, but was "
                    + this.parallelizeSplitUserSize);
        }
        this.w_reconstruct = new double[this.numFactors][this.numItems];
        this.h_analyze = new double[this.numFactors][this.numItems];
        this.initMatrix(this.w_reconstruct);
        this.initMatrix(this.h_analyze);
    }

    /**
     * Fills {@code m} with values drawn uniformly from {@code [initValue, 2 * initValue)}, where
     * {@code initValue = 1 / (2 * numItems)}.
     *
     * <p>A new generator with the same seed is created on every call, so matrices of equal shape
     * receive identical initial values.
     */
    private void initMatrix(double[][] m) {
        double initValue = 1.0 / ((double) this.numItems * 2.0);
        Random random = new Random(INIT_SEED);
        for (double[] row : m) {
            for (int j = 0; j < row.length; ++j) {
                row[j] = (random.nextDouble() + 1.0) * initValue;
            }
        }
    }

    /**
     * Trains the model on a fixed pool with one thread per available processor.
     *
     * <p>The loop bound is inclusive, so {@code numIterations + 1} iterations are run. The pool is
     * shut down even if an iteration fails, so no worker threads are leaked.
     */
    @Override
    public void trainModel() {
        int availableProcessors = Runtime.getRuntime().availableProcessors();
        this.LOG.info("availableProcessors=" + availableProcessors);
        ExecutorService executorService = Executors.newFixedThreadPool(availableProcessors);
        try {
            for (int iter = 0; iter <= this.numIterations; ++iter) {
                this.LOG.info("Starting iteration=" + iter);
                this.train(executorService, iter);
            }
        } finally {
            executorService.shutdown();
        }
    }

    /**
     * Runs one training iteration.
     *
     * <p>It aggregates per-chunk statistics in parallel, then logs the divergence of the current
     * model. If adaptive update rules are enabled it adapts the update exponent. Finally it
     * updates both matrices.
     *
     * @throws IllegalStateException if a task fails or the calling thread is interrupted
     */
    private void train(ExecutorService executorService, int iteration) {
        List<ParallelExecTask> tasks = new ArrayList<>(this.numUsers / this.parallelizeSplitUserSize + 1);
        for (int fromUser = 0; fromUser < this.numUsers; fromUser += this.parallelizeSplitUserSize) {
            int toUserExclusive = Math.min(this.numUsers, fromUser + this.parallelizeSplitUserSize);
            tasks.add(new ParallelExecTask(fromUser, toUserExclusive));
        }
        AggResult total;
        try {
            total = this.aggregate(executorService.invokeAll(tasks));
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can still observe it.
            Thread.currentThread().interrupt();
            this.LOG.error("", e);
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            this.LOG.error("", e);
            throw new IllegalStateException(e);
        }
        double[] wNorm = rowSums(this.w_reconstruct);
        double divergence = this.calculateDivergence(total.boughtItems, total.sumLog, iteration,
                total.resultDenominatorReconstruct2, total.resultDenominatorReconstructDiff);
        if (this.adaptiveUpdateRules) {
            this.adaptExponent(divergence, iteration);
        }
        double[][] newWReconstruct = this.updateReconstruct(total.resultNumeratorReconstruct,
                total.resultDenominatorReconstructDiff, total.resultDenominatorReconstruct2);
        // updateAnalyze must still see the previous w_reconstruct, so the new one is installed afterwards.
        this.updateAnalyze(total.resultNumeratorAnalyze, wNorm, total.countUsersBoughtItem);
        this.w_reconstruct = newWReconstruct;
    }

    /** Sums the statistics of all finished tasks, element by element, in task order. */
    private AggResult aggregate(List<Future<AggResult>> futures) throws InterruptedException, ExecutionException {
        double[][] numeratorAnalyze = new double[this.numFactors][this.numItems];
        double[][] numeratorReconstruct = new double[this.numFactors][this.numItems];
        double[][] denominatorReconstructDiff = new double[this.numFactors][this.numItems];
        double[] denominatorReconstruct2 = new double[this.numFactors];
        int[] countUsersBoughtItem = new int[this.numItems];
        int boughtItems = 0;
        double sumLog = 0.0;
        for (Future<AggResult> future : futures) {
            AggResult part = future.get();
            addInto(numeratorAnalyze, part.resultNumeratorAnalyze);
            addInto(numeratorReconstruct, part.resultNumeratorReconstruct);
            addInto(denominatorReconstructDiff, part.resultDenominatorReconstructDiff);
            addInto(denominatorReconstruct2, part.resultDenominatorReconstruct2);
            for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                countUsersBoughtItem[itemIdx] += part.countUsersBoughtItem[itemIdx];
            }
            boughtItems += part.boughtItems;
            sumLog += part.sumLog;
        }
        return new AggResult(numeratorAnalyze, numeratorReconstruct, denominatorReconstructDiff,
                boughtItems, sumLog, countUsersBoughtItem, denominatorReconstruct2);
    }

    /** Adds {@code source} into {@code target} element-wise; both must have the same shape. */
    private static void addInto(double[][] target, double[][] source) {
        for (int row = 0; row < target.length; ++row) {
            addInto(target[row], source[row]);
        }
    }

    /** Adds {@code source} into {@code target} element-wise; both must have the same length. */
    private static void addInto(double[] target, double[] source) {
        for (int i = 0; i < target.length; ++i) {
            target[i] += source[i];
        }
    }

    /** Returns the sum of every row of {@code matrix}. */
    private static double[] rowSums(double[][] matrix) {
        double[] sums = new double[matrix.length];
        for (int row = 0; row < matrix.length; ++row) {
            double sum = 0.0;
            for (double value : matrix[row]) {
                sum += value;
            }
            sums[row] = sum;
        }
        return sums;
    }

    /**
     * Resets the exponent on the first iteration or when the divergence increased. Otherwise it
     * grows the exponent by one step while it is below the limit. In both cases it remembers
     * {@code divergence} for the next iteration.
     */
    private void adaptExponent(double divergence, int iteration) {
        if (iteration == 0 || divergence > this.divergenceFromLastStep) {
            this.LOG.info("divergence > divergenceFromLastStep. Setting exponent to 0.5.");
            this.exponent = INITIAL_EXPONENT;
        } else {
            if (this.exponent < EXPONENT_GROWTH_LIMIT) {
                this.exponent += EXPONENT_STEP;
            }
            this.LOG.info("divergence <= divergenceFromLastStep. Exponent is now: " + this.exponent);
        }
        this.divergenceFromLastStep = divergence;
    }

    /**
     * Multiplicatively updates {@code h_analyze} in place.
     *
     * <p>The ratio for entry (f, i) is the aggregated numerator divided by
     * {@code countUsersBoughtItem[i] * sum_j w_reconstruct[f][j]}. The sum skips {@code j == i}
     * when {@code doNotEstimateYourself} is set. A {@code NaN} result (for example 0/0 for an item
     * no counted user has) is replaced by 0.
     *
     * @param resultNumeratorAnalyze aggregated numerators, {@code [factor][item]}
     * @param wNorm per-factor row sums of the current {@code w_reconstruct}
     * @param countUsersBoughtItem number of counted users that have each item
     */
    private void updateAnalyze(double[][] resultNumeratorAnalyze, double[] wNorm, int[] countUsersBoughtItem) {
        for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
            for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                double oldValue = this.h_analyze[factorIdx][itemIdx];
                double numerator = resultNumeratorAnalyze[factorIdx][itemIdx];
                double wSum = this.doNotEstimateYourself
                        ? wNorm[factorIdx] - this.w_reconstruct[factorIdx][itemIdx]
                        : wNorm[factorIdx];
                double denominator = (double) countUsersBoughtItem[itemIdx] * wSum;
                double newValue = oldValue * Math.pow(numerator / denominator, this.exponent);
                this.h_analyze[factorIdx][itemIdx] = Double.isNaN(newValue) ? 0.0 : newValue;
            }
        }
    }

    /**
     * Computes the multiplicatively updated {@code w_reconstruct} without modifying the current one.
     *
     * <p>The ratio for entry (f, i) is the aggregated numerator divided by
     * {@code resultDenominatorReconstruct2[f] - resultDenominatorReconstructDiff[f][i]}. As in
     * {@link #updateAnalyze}, a {@code NaN} result is replaced by 0 so one degenerate entry cannot
     * poison the whole factor in later iterations.
     *
     * @return a new {@code [factor][item]} matrix
     */
    private double[][] updateReconstruct(double[][] resultNumeratorReconstruct, double[][] resultDenominatorReconstructDiff, double[] resultDenominatorReconstruct2) {
        double[][] newWReconstruct = new double[this.numFactors][this.numItems];
        for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
            double denominator = resultDenominatorReconstruct2[factorIdx];
            for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                double oldValue = this.w_reconstruct[factorIdx][itemIdx];
                double numerator = resultNumeratorReconstruct[factorIdx][itemIdx];
                double denominatorDiff = resultDenominatorReconstructDiff[factorIdx][itemIdx];
                double newValue = oldValue * Math.pow(numerator / (denominator - denominatorDiff), this.exponent);
                newWReconstruct[factorIdx][itemIdx] = Double.isNaN(newValue) ? 0.0 : newValue;
            }
        }
        return newWReconstruct;
    }

    /**
     * Computes and logs {@code sumLog - countAll + sumAllEstimate} for the current model.
     *
     * <p>{@code sumAllEstimate} is
     * {@code sum_f sum_i w_reconstruct[f][i] * (resultDenominatorReconstruct[f] - resultDenominatorReconstructDiff[f][i])}.
     *
     * @return the divergence value
     */
    private double calculateDivergence(int countAll, double sumLog, int iteration, double[] resultDenominatorReconstruct, double[][] resultDenominatorReconstructDiff) {
        double sumAllEstimate = 0.0;
        for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
            for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                double userWeight = resultDenominatorReconstruct[factorIdx] - resultDenominatorReconstructDiff[factorIdx][itemIdx];
                sumAllEstimate += this.w_reconstruct[factorIdx][itemIdx] * userWeight;
            }
        }
        double divergence = sumLog - (double) countAll + sumAllEstimate;
        this.LOG.info("Divergence (before iteration " + iteration + ")=" + divergence + "  sumLog=" + sumLog + "  countAll=" + countAll + "  sumAllEstimate=" + sumAllEstimate);
        return divergence;
    }

    /**
     * Scores {@code itemIdx} for a user given that user's training items.
     *
     * <p>This uses the same model as training. The user's factor vector is the sum of
     * {@code h_analyze} over the user's items ({@link #predictFactors}), weighted per factor by
     * {@code w_reconstruct[f][itemIdx]}.
     */
    private double predict(SequentialSparseVector itemRatingsVector, int itemIdx) {
        double[] userFactors = this.predictFactors(itemRatingsVector.getIndices());
        double score = 0.0;
        for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
            score += this.w_reconstruct[factorIdx][itemIdx] * userFactors[factorIdx];
        }
        return score;
    }

    /** Sums {@code h_analyze} over {@code itemIndices}, giving one value per factor. */
    private double[] predictFactors(int[] itemIndices) {
        double[] latentFactors = new double[this.numFactors];
        for (int itemIdx : itemIndices) {
            for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                latentFactors[factorIdx] += this.h_analyze[factorIdx][itemIdx];
            }
        }
        return latentFactors;
    }

    /** Scores {@code itemIdx} for {@code userIdx} using the user's row of the training matrix. */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        SequentialSparseVector itemRatingsVector = this.trainMatrix.row(userIdx);
        return this.predict(itemRatingsVector, itemIdx);
    }

    /**
     * Writes {@code w_reconstruct.csv} and {@code h_analyze.csv} into {@code directoryPath}.
     *
     * <p>Missing parent directories are created. Failures are logged rather than thrown.
     */
    @Override
    public void saveModel(String directoryPath) {
        File dir = new File(directoryPath);
        dir.mkdirs();
        try {
            File wFile = new File(dir, "w_reconstruct.csv");
            this.LOG.info("Writing matrix w_reconstruct to file=" + wFile.getAbsolutePath());
            this.saveMatrix(wFile, this.w_reconstruct);
            File hFile = new File(dir, "h_analyze.csv");
            this.LOG.info("Writing matrix h_analyze to file=" + hFile.getAbsolutePath());
            this.saveMatrix(hFile, this.h_analyze);
        }
        catch (Exception e) {
            this.LOG.error("Could not save model", e);
        }
    }

    /**
     * Writes {@code matrix} ({@code [factor][item]}) as a UTF-8 CSV file with one row per item.
     *
     * <p>Each row holds the quoted external item id followed by one column per factor. Lines end
     * with CRLF. The writer is closed even if writing fails.
     */
    private void saveMatrix(File file, double[][] matrix) throws IOException {
        BiMap<Integer, String> items = this.itemMappingData.inverse();
        try (BufferedWriter writer = Files.newBufferedWriter(file.toPath(), StandardCharsets.UTF_8)) {
            writer.write("\"item_id\"");
            for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                writer.write(',');
                writer.write("\"factor");
                writer.write(Integer.toString(factorIdx));
                writer.write('"');
            }
            writer.write("\r\n");
            for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                writer.write('"');
                writer.write(items.get(itemIdx));
                writer.write('"');
                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    writer.write(',');
                    writer.write(Double.toString(matrix[factorIdx][itemIdx]));
                }
                writer.write("\r\n");
            }
        }
    }

    /**
     * Collects the training statistics for the users in {@code [fromUser, toUser)}.
     *
     * <p>It only reads the enclosing recommender's matrices, and returns freshly allocated arrays.
     */
    private class ParallelExecTask
    implements Callable<AggResult> {
        private final int fromUser;
        private final int toUser;

        public ParallelExecTask(int fromUser, int toUser) {
            this.fromUser = fromUser;
            this.toUser = toUser;
        }

        @Override
        public AggResult call() {
            int factors = numFactors;
            int items = numItems;
            double[][] numeratorAnalyze = new double[factors][items];
            double[][] numeratorReconstruct = new double[factors][items];
            double[][] denominatorReconstructDiff = new double[factors][items];
            double[] denominatorReconstruct = new double[factors];
            int[] countUsersBoughtItem = new int[items];
            int boughtItems = 0;
            double sumLog = 0.0;
            // When an item's own weight is excluded, a user needs at least one other item to estimate from.
            int minCount = doNotEstimateYourself ? 2 : 1;
            for (int userIdx = this.fromUser; userIdx < this.toUser; ++userIdx) {
                SequentialSparseVector itemRatingsVector = trainMatrix.row(userIdx);
                if (itemRatingsVector.getNumEntries() < minCount) {
                    continue;
                }
                int[] itemIndices = itemRatingsVector.getIndices();
                double[] allUserLatentFactors = predictFactors(itemIndices);
                for (int factorIdx = 0; factorIdx < factors; ++factorIdx) {
                    denominatorReconstruct[factorIdx] += allUserLatentFactors[factorIdx];
                }
                // Per factor: sum over this user's items of w_reconstruct[f][item] / estimate(item).
                double[] analyzeNumerator = new double[factors];
                for (int itemIdx : itemIndices) {
                    double[] thisUserLatentFactors = new double[factors];
                    for (int factorIdx = 0; factorIdx < factors; ++factorIdx) {
                        if (doNotEstimateYourself) {
                            denominatorReconstructDiff[factorIdx][itemIdx] += h_analyze[factorIdx][itemIdx];
                            thisUserLatentFactors[factorIdx] = allUserLatentFactors[factorIdx] - h_analyze[factorIdx][itemIdx];
                        } else {
                            thisUserLatentFactors[factorIdx] = allUserLatentFactors[factorIdx];
                        }
                    }
                    double estimate = 0.0;
                    for (int factorIdx = 0; factorIdx < factors; ++factorIdx) {
                        estimate += thisUserLatentFactors[factorIdx] * w_reconstruct[factorIdx][itemIdx];
                    }
                    double estimateFactor = 1.0 / estimate;
                    sumLog += Math.log(estimateFactor);
                    ++boughtItems;
                    ++countUsersBoughtItem[itemIdx];
                    for (int factorIdx = 0; factorIdx < factors; ++factorIdx) {
                        numeratorReconstruct[factorIdx][itemIdx] += estimateFactor * thisUserLatentFactors[factorIdx];
                        double numerator = estimateFactor * w_reconstruct[factorIdx][itemIdx];
                        analyzeNumerator[factorIdx] += numerator;
                        if (doNotEstimateYourself) {
                            // The item's own estimate does not depend on its own analysis weight.
                            numeratorAnalyze[factorIdx][itemIdx] -= numerator;
                        }
                    }
                }
                for (int boughtIdx : itemIndices) {
                    for (int factorIdx = 0; factorIdx < factors; ++factorIdx) {
                        numeratorAnalyze[factorIdx][boughtIdx] += analyzeNumerator[factorIdx];
                    }
                }
            }
            return new AggResult(numeratorAnalyze, numeratorReconstruct, denominatorReconstructDiff,
                    boughtItems, sumLog, countUsersBoughtItem, denominatorReconstruct);
        }
    }

    /** Immutable holder for the statistics produced by one task (or the sum of all tasks). */
    private static class AggResult {
        private final double[][] resultNumeratorAnalyze;
        private final double[][] resultNumeratorReconstruct;
        private final double[][] resultDenominatorReconstructDiff;
        private final int boughtItems;
        private final double sumLog;
        private final int[] countUsersBoughtItem;
        private final double[] resultDenominatorReconstruct2;

        public AggResult(double[][] resultNumeratorAnalyze, double[][] resultNumeratorReconstruct, double[][] resultDenominatorReconstructDiff, int boughtItems, double sumLog, int[] countUsersBoughtItem, double[] resultDenominatorReconstruct2) {
            this.resultNumeratorAnalyze = resultNumeratorAnalyze;
            this.resultNumeratorReconstruct = resultNumeratorReconstruct;
            this.resultDenominatorReconstructDiff = resultDenominatorReconstructDiff;
            this.boughtItems = boughtItems;
            this.sumLog = sumLog;
            this.countUsersBoughtItem = countUsersBoughtItem;
            this.resultDenominatorReconstruct2 = resultDenominatorReconstruct2;
        }
    }
}

