/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import com.google.common.collect.BiMap;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import net.librec.common.LibrecException;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Projective Non-negative Matrix Factorization (PNMF) recommender.
 *
 * <p>The model is a single matrix {@code w} indexed {@code [factorIdx][itemIdx]}. A user's latent
 * vector is the projection of the user's set of training items onto {@code w} (the sum of the
 * {@code w} columns of those items), and an item's score is the dot product of that latent vector
 * with the item's column of {@code w}. Only the indices of the training matrix are used; rating
 * values are ignored, so the model fits the presence/absence (binary) user-item matrix.
 *
 * <p>Training applies multiplicative updates and logs, once per iteration, the generalized
 * Kullback-Leibler divergence between the binary matrix and its reconstruction. Users are
 * processed in parallel in slices of {@link #PARALLELIZE_USER_SPLIT_SIZE}.
 */
public class PNMFRecommender extends MatrixFactorizationRecommender {
    /** Number of consecutive users handled by a single parallel task. */
    private static final int PARALLELIZE_USER_SPLIT_SIZE = 5000;
    /**
     * Factor matrix, indexed {@code [factorIdx][itemIdx]}. Entries are initialised positive and
     * are only ever rescaled multiplicatively.
     */
    private double[][] w;
    /** Number of latent factors (rows of {@code w}). */
    private int numFactors;
    /** Number of training passes performed by {@link #trainModel()}. */
    private int numIterations;

    /**
     * Reads {@code rec.factor.number} (default 15) and {@code rec.iterator.maximum}
     * (default 100) and allocates and initialises {@code w}.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.numFactors = this.conf.getInt("rec.factor.number", 15);
        this.numIterations = this.conf.getInt("rec.iterator.maximum", 100);
        this.w = new double[this.numFactors][this.numItems];
        this.initMatrix(this.w);
    }

    /**
     * Fills {@code m} with pseudo-random values in {@code [1/(2*numItems), 1/numItems)}.
     * A fixed seed is used so training runs are reproducible.
     *
     * @param m matrix to fill in place
     */
    private void initMatrix(double[][] m) {
        double initValue = 1.0 / ((double) this.numItems * 2.0);
        Random random = new Random(123456789L);
        for (int i = 0; i < m.length; ++i) {
            for (int j = 0; j < m[i].length; ++j) {
                m[i][j] = (random.nextDouble() + 1.0) * initValue;
            }
        }
    }

    /**
     * Runs {@code numIterations} multiplicative-update passes over all users on a fixed-size
     * thread pool with one thread per available processor.
     *
     * <p>The pool is shut down in a {@code finally} block so that its (non-daemon) worker
     * threads are released even when an iteration fails.
     */
    @Override
    public void trainModel() {
        int availableProcessors = Runtime.getRuntime().availableProcessors();
        this.LOG.info("availableProcessors=" + availableProcessors);
        ExecutorService executorService = Executors.newFixedThreadPool(availableProcessors);
        try {
            // "rec.iterator.maximum" is an upper bound: perform exactly numIterations passes.
            for (int iter = 0; iter < this.numIterations; ++iter) {
                this.LOG.info("Starting iteration=" + iter);
                this.train(executorService, iter);
            }
        } finally {
            executorService.shutdown();
        }
    }

    /**
     * Performs one multiplicative update of {@code w}.
     *
     * <p>Per-user statistics are computed in parallel (one task per slice of users), merged,
     * the divergence of the current model is logged, and then every entry of {@code w} is
     * rescaled.
     *
     * @param executorService pool on which the per-slice tasks run
     * @param iteration zero-based iteration number, used only for logging
     * @throws IllegalStateException if a task fails or the calling thread is interrupted
     */
    private void train(ExecutorService executorService, int iteration) {
        List<ParallelExecTask> tasks =
                new ArrayList<>(this.numUsers / PARALLELIZE_USER_SPLIT_SIZE + 1);
        for (int fromUser = 0; fromUser < this.numUsers; fromUser += PARALLELIZE_USER_SPLIT_SIZE) {
            int toUserExclusive = Math.min(this.numUsers, fromUser + PARALLELIZE_USER_SPLIT_SIZE);
            tasks.add(new ParallelExecTask(fromUser, toUserExclusive));
        }
        AggResult total;
        try {
            total = this.mergePartialResults(executorService.invokeAll(tasks));
        } catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
            this.LOG.error("", e);
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            this.LOG.error("", e);
            throw new IllegalStateException(e);
        }
        // Row sums must be taken from the pre-update w, before any entry is rescaled.
        double[] wNorm = this.sumFactorRows();
        this.printDivergence(total.summedLatentFactors, total.countUsersBoughtItem, total.sumLog, wNorm, iteration);
        this.applyMultiplicativeUpdate(total, wNorm);
    }

    /**
     * Sums, element-wise and in list order, the partial statistics produced by the tasks.
     *
     * @param results futures of the per-slice tasks, as returned by {@code invokeAll}
     * @return the totals over all users
     * @throws InterruptedException if interrupted while retrieving a result
     * @throws ExecutionException if any task threw
     */
    private AggResult mergePartialResults(List<Future<AggResult>> results)
            throws InterruptedException, ExecutionException {
        double[][] resultNumerator = new double[this.numFactors][this.numItems];
        double[] summedLatentFactors = new double[this.numFactors];
        int[] countUsersBoughtItem = new int[this.numItems];
        double sumLog = 0.0;
        for (Future<AggResult> future : results) {
            AggResult partial = future.get();
            for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                double[] target = resultNumerator[factorIdx];
                double[] source = partial.resultNumerator[factorIdx];
                for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                    target[itemIdx] += source[itemIdx];
                }
                summedLatentFactors[factorIdx] += partial.summedLatentFactors[factorIdx];
            }
            for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                countUsersBoughtItem[itemIdx] += partial.countUsersBoughtItem[itemIdx];
            }
            sumLog += partial.sumLog;
        }
        return new AggResult(resultNumerator, summedLatentFactors, countUsersBoughtItem, sumLog);
    }

    /**
     * Returns, for every factor, the sum of its row of {@code w} over all items.
     *
     * @return array of length {@code numFactors}
     */
    private double[] sumFactorRows() {
        double[] wNorm = new double[this.numFactors];
        for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
            double rowSum = 0.0;
            for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                rowSum += this.w[factorIdx][itemIdx];
            }
            wNorm[factorIdx] = rowSum;
        }
        return wNorm;
    }

    /**
     * Rescales every entry of {@code w} in place:
     * {@code w[f][i] *= sqrt(numerator[f][i] / (count[i] * wNorm[f] + summedLatentFactors[f]))}.
     * A warning is logged for every entry that becomes NaN (e.g. a 0/0 ratio).
     *
     * @param total merged statistics for the current iteration
     * @param wNorm row sums of {@code w} taken before this update
     */
    private void applyMultiplicativeUpdate(AggResult total, double[] wNorm) {
        for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
            double[] row = this.w[factorIdx];
            double[] numerators = total.resultNumerator[factorIdx];
            for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                double oldValue = row[itemIdx];
                double numerator = numerators[itemIdx];
                double denominator = (double) total.countUsersBoughtItem[itemIdx] * wNorm[factorIdx]
                        + total.summedLatentFactors[factorIdx];
                double newValue = oldValue * StrictMath.sqrt(numerator / denominator);
                if (Double.isNaN(newValue)) {
                    this.LOG.warn("Double.isNaN  " + numerator + " " + denominator + " " + oldValue + " " + newValue);
                }
                row[itemIdx] = newValue;
            }
        }
    }

    /**
     * Logs the generalized KL divergence of the current model against the binary training
     * matrix: {@code sumLog - (number of observed pairs) + (sum of all estimates)}.
     *
     * @param summedLatentFactors per-factor sum of all users' latent values
     * @param countUsersBoughtItem per-item number of users that have the item
     * @param sumLog sum of {@code log(1 / estimate)} over all observed user-item pairs
     * @param wNorm per-factor row sums of {@code w}
     * @param iteration zero-based iteration number, used only for logging
     */
    private void printDivergence(double[] summedLatentFactors, int[] countUsersBoughtItem, double sumLog, double[] wNorm, int iteration) {
        // long: the total number of observations across all items may exceed Integer.MAX_VALUE.
        long countAll = 0L;
        for (int count : countUsersBoughtItem) {
            countAll += count;
        }
        // Sum over all (user, item) of the estimate factorises as sum_f wNorm[f] * summed[f].
        double sumAllEstimate = 0.0;
        for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
            sumAllEstimate += wNorm[factorIdx] * summedLatentFactors[factorIdx];
        }
        double divergence = sumLog - (double) countAll + sumAllEstimate;
        this.LOG.info("Divergence (before iteration " + iteration + ")=" + divergence);
    }

    /**
     * Scores {@code itemIdx} for a user whose training items are the indices of
     * {@code itemRatingsVector}.
     *
     * @param itemRatingsVector the user's row of the training matrix (only indices are used)
     * @param itemIdx item to score
     * @return dot product of the user's latent vector with column {@code itemIdx} of {@code w}
     */
    private double predict(SequentialSparseVector itemRatingsVector, int itemIdx) {
        double score = 0.0;
        for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
            score += this.w[factorIdx][itemIdx] * this.predictFactor(itemRatingsVector, factorIdx);
        }
        return score;
    }

    /**
     * Computes one component of a user's latent vector.
     *
     * @param itemRatingsVector the user's row of the training matrix (only indices are used)
     * @param factorIdx factor to compute
     * @return sum of {@code w[factorIdx][j]} over the user's items {@code j}
     */
    private double predictFactor(SequentialSparseVector itemRatingsVector, int factorIdx) {
        double sum = 0.0;
        for (int itemIdx : itemRatingsVector.getIndices()) {
            sum += this.w[factorIdx][itemIdx];
        }
        return sum;
    }

    /**
     * Computes a user's full latent vector.
     *
     * @param itemRatingsVector the user's row of the training matrix (only indices are used)
     * @return array of length {@code numFactors}; entry {@code f} is the sum of
     *     {@code w[f][j]} over the user's items {@code j}
     */
    private double[] predictFactors(SequentialSparseVector itemRatingsVector) {
        double[] latentFactors = new double[this.numFactors];
        for (int itemIdx : itemRatingsVector.getIndices()) {
            for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                latentFactors[factorIdx] += this.w[factorIdx][itemIdx];
            }
        }
        return latentFactors;
    }

    /**
     * Scores {@code itemIdx} for {@code userIdx} using the user's row of the training matrix.
     *
     * @throws LibrecException declared by the overridden method; not thrown here
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        SequentialSparseVector itemRatingsVector = this.trainMatrix.row(userIdx);
        return this.predict(itemRatingsVector, itemIdx);
    }

    /**
     * Writes {@code w} as CSV: a header {@code "item_id","factor0",...}, then one row per item
     * holding the quoted item id followed by that item's factor values.
     *
     * <p>Failures are logged and not rethrown (best-effort save); the file is always closed.
     *
     * @param filePath destination file, overwritten if it exists
     */
    @Override
    public void saveModel(String filePath) {
        this.LOG.info("Writing matrix W to file=" + filePath);
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(filePath))) {
            writer.write("\"item_id\"");
            for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                writer.write(',');
                // Keep the index inside the quotes so the header is valid CSV ("factor0", ...).
                writer.write("\"factor" + factorIdx + "\"");
            }
            writer.newLine();
            BiMap<?, ?> items = this.itemMappingData.inverse();
            for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                writer.write(quoteCsvField((String) items.get(itemIdx)));
                for (int factorIdx = 0; factorIdx < this.numFactors; ++factorIdx) {
                    writer.write(',');
                    writer.write(Double.toString(this.w[factorIdx][itemIdx]));
                }
                writer.newLine();
            }
        } catch (Exception e) {
            this.LOG.error("Could not save model", e);
        }
    }

    /**
     * Quotes a CSV field, doubling embedded quote characters (RFC 4180).
     *
     * @param value field value; must not be null
     * @return the quoted field
     */
    private static String quoteCsvField(String value) {
        return "\"" + value.replace("\"", "\"\"") + "\"";
    }

    /**
     * Computes the partial training statistics for users {@code [fromUser, toUser)}.
     * Reads {@code w} and the training matrix only; {@code w} is modified by the enclosing
     * recommender only after all tasks of an iteration have completed.
     */
    private class ParallelExecTask implements Callable<AggResult> {
        private final int fromUser;
        private final int toUser;

        /**
         * @param fromUser first user index (inclusive)
         * @param toUser last user index (exclusive)
         */
        public ParallelExecTask(int fromUser, int toUser) {
            this.fromUser = fromUser;
            this.toUser = toUser;
        }

        @Override
        public AggResult call() {
            int factorCount = PNMFRecommender.this.numFactors;
            double[][] weights = PNMFRecommender.this.w;
            double[][] resultNumerator = new double[factorCount][PNMFRecommender.this.numItems];
            double[] summedLatentFactors = new double[factorCount];
            int[] countUsersBoughtItem = new int[PNMFRecommender.this.numItems];
            double sumLog = 0.0;
            for (int userIdx = this.fromUser; userIdx < this.toUser; ++userIdx) {
                SequentialSparseVector itemRatingsVector = PNMFRecommender.this.trainMatrix.row(userIdx);
                if (itemRatingsVector.getNumEntries() <= 0) {
                    continue;
                }
                double[] userFactors = PNMFRecommender.this.predictFactors(itemRatingsVector);
                for (int factorIdx = 0; factorIdx < factorCount; ++factorIdx) {
                    summedLatentFactors[factorIdx] += userFactors[factorIdx];
                }
                // secondTermNumerator[f] = sum over the user's items j of w[f][j] / estimate(u, j);
                // it is added to the numerator of every item the user has (second loop below).
                double[] secondTermNumerator = new double[factorCount];
                for (int itemIdx : itemRatingsVector.getIndices()) {
                    double estimate = 0.0;
                    for (int factorIdx = 0; factorIdx < factorCount; ++factorIdx) {
                        estimate += userFactors[factorIdx] * weights[factorIdx][itemIdx];
                    }
                    double estimateFactor = 1.0 / estimate;
                    sumLog += Math.log(estimateFactor);
                    countUsersBoughtItem[itemIdx] += 1;
                    for (int factorIdx = 0; factorIdx < factorCount; ++factorIdx) {
                        resultNumerator[factorIdx][itemIdx] += estimateFactor * userFactors[factorIdx];
                        secondTermNumerator[factorIdx] += estimateFactor * weights[factorIdx][itemIdx];
                    }
                }
                for (int itemIdx : itemRatingsVector.getIndices()) {
                    for (int factorIdx = 0; factorIdx < factorCount; ++factorIdx) {
                        resultNumerator[factorIdx][itemIdx] += secondTermNumerator[factorIdx];
                    }
                }
            }
            return new AggResult(resultNumerator, summedLatentFactors, countUsersBoughtItem, sumLog);
        }
    }

    /** Training statistics for a set of users; partial results are summed element-wise. */
    private static class AggResult {
        /** Numerator of the multiplicative update, indexed {@code [factorIdx][itemIdx]}. */
        private final double[][] resultNumerator;
        /** Per-factor sum of the users' latent values. */
        private final double[] summedLatentFactors;
        /** Per-item number of users having the item in their training row. */
        private final int[] countUsersBoughtItem;
        /** Sum of {@code log(1 / estimate)} over the observed user-item pairs. */
        private final double sumLog;

        public AggResult(double[][] resultNumerator, double[] summedLatentFactors, int[] countUsersBoughtItem, double sumLog) {
            this.resultNumerator = resultNumerator;
            this.summedLatentFactors = summedLatentFactors;
            this.countUsersBoughtItem = countUsersBoughtItem;
            this.sumLog = sumLog;
        }
    }
}

