/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import com.google.common.collect.BiMap;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.SequentialAccessSparseMatrix;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.SymmMatrix;
import net.librec.math.structure.Vector;
import net.librec.recommender.MatrixFactorizationRecommender;
import net.librec.util.Lists;

/**
 * Item-neighbourhood recommender with a SLIM-style sparse coefficient matrix
 * and an additional "balance" regulariser driven by a protected item feature.
 *
 * <p>Every item is assigned a group membership of {@code +1} (the item carries
 * the protected feature with value {@code 1.0}) or {@code -1} (otherwise).
 * While the item-item coefficients are fitted coordinate-wise with L1/L2
 * regularisation, the membership-weighted sum of the coefficients used for a
 * prediction (the "balance") is penalised with weight {@code lambda3}.
 *
 * <p>Configuration keys read in {@link #setup()}:
 * {@code rec.neighbors.knn.number}, {@code rec.iterator.maximum},
 * {@code rec.slim.regularization.l1}, {@code rec.slim.regularization.l2},
 * {@code rec.bnslim.regularization.l3}, {@code rec.bnslim.minsimilarity} and
 * {@code data.protected.feature}.
 *
 * <p>NOTE(review): {@link #knn} is a static field that is overwritten by every
 * call to {@link #setup()}, so all instances in a JVM share the last configured
 * value. It is kept static because it is part of the protected interface.
 */
@ModelData(value={"isRanking", "slim", "coefficientMatrix", "trainMatrix", "similarityMatrix", "knn"})
public class BLNSLIMFastRecommender
extends MatrixFactorizationRecommender {
    /** Maximum number of training sweeps over all items. */
    protected int numIterations;
    /** Learned item-item coefficients, addressed as (neighbourItem, targetItem). */
    private DenseMatrix coefficientMatrix;
    /** Per-item neighbour sets; only populated when {@code knn > 0}. */
    private Set<Integer>[] itemNNs;
    /** Weight of the L1 penalty on a coefficient. */
    private float regL1Norm;
    /** Weight of the L2 penalty on a coefficient. */
    private float regL2Norm;
    /** Weight of the balance penalty. */
    private float lambda3;
    /** Per-item group membership: +1 for protected items, -1 for all others. */
    private int[] groupMembershipVector;
    /** Item-by-feature matrix obtained from the data model's feature appender. */
    protected SequentialAccessSparseMatrix itemFeatureMatrix;
    /** Name of the protected item feature (configuration key {@code data.protected.feature}). */
    private String protectedAttribute;
    /** Maps feature names to their column index in {@link #itemFeatureMatrix}. */
    BiMap<String, Integer> featureIdMapping;
    /** Balance term of the most recent {@link #predictFast(int, int, int)} call. */
    private double balance;
    /** Sum of the averaged balance terms accumulated during the current iteration (logged only). */
    private double weights;
    /** Neighbourhood size; a value {@code <= 0} means "use every item as a neighbour". */
    protected static int knn;
    /** Item-item similarities supplied by the recommender context. */
    private SymmMatrix similarityMatrix;
    /** Indices of all items; used as the neighbourhood when {@code knn <= 0}. */
    private Set<Integer> allItems;
    /** Neighbours whose similarity is not strictly greater than this value are skipped in training. */
    private float minSimThresh;

    /**
     * Reads the configuration, initialises the coefficient matrix with a zero
     * diagonal, builds the neighbour sets and derives the group membership of
     * every item from the protected feature.
     *
     * @throws LibrecException if the parent setup fails, or if an item carries
     *         features but the configured protected feature is not present in
     *         the item feature mapping
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        knn = this.conf.getInt("rec.neighbors.knn.number", 50);
        this.numIterations = this.conf.getInt("rec.iterator.maximum");
        this.regL1Norm = this.conf.getFloat("rec.slim.regularization.l1", Float.valueOf(1.0f)).floatValue();
        this.regL2Norm = this.conf.getFloat("rec.slim.regularization.l2", Float.valueOf(1.0f)).floatValue();
        this.lambda3 = this.conf.getFloat("rec.bnslim.regularization.l3", Float.valueOf(1.0f)).floatValue();
        this.minSimThresh = this.conf.getFloat("rec.bnslim.minsimilarity", Float.valueOf(-1.0f)).floatValue();
        this.protectedAttribute = this.conf.get("data.protected.feature");
        System.out.println("***");
        System.out.println("l1 reg: " + this.regL1Norm);
        System.out.println("l2 reg: " + this.regL2Norm);
        System.out.println("balance controller l3: " + this.lambda3);
        System.out.println("***");

        this.coefficientMatrix = new DenseMatrix(this.numItems, this.numItems);
        this.coefficientMatrix.init();
        this.similarityMatrix = this.context.getSimilarity().getSimilarityMatrix();
        // An item must never be used to predict itself.
        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            this.coefficientMatrix.set(itemIdx, itemIdx, 0.0);
        }
        this.createItemNNs();

        this.itemFeatureMatrix = this.getDataModel().getFeatureAppender().getItemFeatures();
        this.featureIdMapping = this.getDataModel().getFeatureAppender().getItemFeatureMap();
        // Resolve the protected feature once instead of once per item.
        Integer protectedFeatureIdx = null;
        if (this.featureIdMapping != null && this.protectedAttribute != null) {
            protectedFeatureIdx = this.featureIdMapping.get(this.protectedAttribute);
        }
        this.groupMembershipVector = new int[this.numItems];
        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            int itemMembership = -1;
            if (this.itemFeatureMatrix.row(itemIdx).size() > 0) {
                if (protectedFeatureIdx == null) {
                    // Previously surfaced as a bare NullPointerException while unboxing.
                    throw new LibrecException("Protected feature '" + this.protectedAttribute
                            + "' (data.protected.feature) was not found in the item feature mapping");
                }
                if (this.itemFeatureMatrix.get(itemIdx, protectedFeatureIdx.intValue()) == 1.0) {
                    itemMembership = 1;
                }
            }
            this.groupMembershipVector[itemIdx] = itemMembership;
        }
    }

    /**
     * Fits the coefficient matrix by repeated coordinate-wise updates: for every
     * target item and each of its admissible neighbours, a soft-thresholded
     * update is computed from the averaged prediction error and the averaged
     * balance term. Stops after {@link #numIterations} sweeps, or earlier when
     * {@link #isConverged(int)} reports convergence and early stopping is on.
     *
     * @throws LibrecException declared by the overridden method
     */
    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            this.weights = 0.0;
            for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                Set<Integer> neighbors = this.neighborsOf(itemIdx);
                int itemMembership = this.groupMembershipVector[itemIdx];

                // Dense copy of the target item's ratings for O(1) lookup by user.
                double[] userRatingEntries = new double[this.numUsers];
                SequentialSparseVector itemRatingVec = this.trainMatrix.column(itemIdx);
                for (Vector.VectorEntry ve : itemRatingVec) {
                    userRatingEntries[ve.index()] = ve.get();
                }

                for (Integer nearestNeighborItemIdx : neighbors) {
                    if (nearestNeighborItemIdx == itemIdx) {
                        continue;
                    }
                    double sim = this.similarityMatrix.get(nearestNeighborItemIdx, itemIdx);
                    // Written as !(a > b) so that a NaN similarity is skipped as well.
                    if (!(sim > (double) this.minSimThresh)) {
                        continue;
                    }
                    SequentialSparseVector nnUserRatingVec = this.trainMatrix.column(nearestNeighborItemIdx);
                    if (nnUserRatingVec.size() == 0) {
                        continue;
                    }

                    double gradSum = 0.0;
                    double rateSum = 0.0;
                    double errors = 0.0;
                    double itemBalanceSumSqr = 0.0;
                    double itemBalanceSum = 0.0;
                    int nnCount = 0;
                    for (Vector.VectorEntry nnUserVectorEntry : nnUserRatingVec) {
                        int nnUserIdx = nnUserVectorEntry.index();
                        double nnRating = nnUserVectorEntry.get();
                        double rating = userRatingEntries[nnUserIdx];
                        // predictFast also refreshes this.balance for the same (user, item) pair.
                        double error = rating - this.predictFast(nnUserIdx, itemIdx, nearestNeighborItemIdx);
                        double itemBalance = this.balance;
                        itemBalanceSumSqr += itemBalance * itemBalance;
                        itemBalanceSum += itemBalance;
                        gradSum += nnRating * error;
                        rateSum += nnRating * nnRating;
                        errors += error * error;
                        ++nnCount;
                    }
                    itemBalanceSumSqr /= (double) nnCount;
                    itemBalanceSum /= (double) nnCount;
                    gradSum /= (double) nnCount;
                    rateSum /= (double) nnCount;
                    errors /= (double) nnCount;

                    double coefficient = this.coefficientMatrix.get(nearestNeighborItemIdx, itemIdx);
                    // The update below may yield negative coefficients, so the L1
                    // penalty has to use the absolute value of the coefficient.
                    this.loss += 0.5 * errors
                            + 0.5 * (double) this.regL2Norm * coefficient * coefficient
                            + (double) this.regL1Norm * Math.abs(coefficient)
                            + 0.5 * (double) this.lambda3 * itemBalanceSumSqr;
                    this.weights += itemBalanceSum;

                    double beta = gradSum + (double) (this.lambda3 * itemMembership) * itemBalanceSum;
                    double update = 0.0;
                    if ((double) this.regL1Norm < Math.abs(beta)) {
                        // Soft-thresholding: shrink beta towards zero by the L1 weight.
                        double shrunk = beta > 0.0 ? beta - (double) this.regL1Norm : beta + (double) this.regL1Norm;
                        update = shrunk / ((double) this.regL2Norm + rateSum + (double) this.lambda3);
                    }
                    this.coefficientMatrix.set(nearestNeighborItemIdx, itemIdx, update);
                }
            }
            // isConverged must run every iteration: it logs and updates lastLoss.
            if (this.isConverged(iter) && this.earlyStop) break;
        }
    }

    /**
     * Predicts a rating as the coefficient-weighted sum of the user's ratings on
     * the neighbours of {@code itemIdx}, leaving out one neighbour.
     *
     * @param userIdx         index of the user
     * @param itemIdx         index of the target item
     * @param excludedItemIdx neighbour item to leave out of the sum ({@code -1} for none)
     * @return the predicted rating
     */
    protected double predict(int userIdx, int itemIdx, int excludedItemIdx) {
        Set<Integer> neighbors = this.neighborsOf(itemIdx);
        double predictRating = 0.0;
        SequentialSparseVector userRatingVec = this.trainMatrix.row(userIdx);
        for (Vector.VectorEntry itemEntry : userRatingVec) {
            int nearestNeighborItemIdx = itemEntry.index();
            if (nearestNeighborItemIdx == excludedItemIdx || !neighbors.contains(nearestNeighborItemIdx)) {
                continue;
            }
            predictRating += itemEntry.get() * this.coefficientMatrix.get(nearestNeighborItemIdx, itemIdx);
        }
        return predictRating;
    }

    /**
     * Same prediction as {@link #predict(int, int, int)}, but additionally stores
     * the membership-weighted sum of the coefficients that were used in the
     * {@code balance} field as a side effect (it is reset to zero first).
     *
     * @param userIdx         index of the user
     * @param itemIdx         index of the target item
     * @param excludedItemIdx neighbour item to leave out of the sum ({@code -1} for none)
     * @return the predicted rating
     */
    protected double predictFast(int userIdx, int itemIdx, int excludedItemIdx) {
        Set<Integer> neighbors = this.neighborsOf(itemIdx);
        double predictRating = 0.0;
        this.balance = 0.0;
        SequentialSparseVector userRatingVec = this.trainMatrix.row(userIdx);
        for (Vector.VectorEntry itemEntry : userRatingVec) {
            int nearestNeighborItemIdx = itemEntry.index();
            if (nearestNeighborItemIdx == excludedItemIdx || !neighbors.contains(nearestNeighborItemIdx)) {
                continue;
            }
            double coeff = this.coefficientMatrix.get(nearestNeighborItemIdx, itemIdx);
            predictRating += itemEntry.get() * coeff;
            this.balance += (double) this.groupMembershipVector[nearestNeighborItemIdx] * coeff;
        }
        return predictRating;
    }

    /**
     * Records the current loss as the last loss, optionally logs progress, and
     * reports whether the loss decreased by less than {@code 1e-5} (which
     * includes an increase) compared with the previous iteration.
     *
     * @param iter 1-based iteration number
     * @return {@code true} from the second iteration on when the loss delta is below {@code 1e-5}
     */
    @Override
    protected boolean isConverged(int iter) {
        double delta_loss = this.lastLoss - this.loss;
        this.lastLoss = this.loss;
        if (verbose) {
            String recName = this.getClass().getSimpleName();
            String info = recName + " iter " + iter + ": loss = " + this.loss + ", delta_loss = " + delta_loss;
            this.LOG.info(info);
            this.LOG.info("The item balance sum is " + this.weights + "\n");
        }
        return iter > 1 ? delta_loss < 1.0E-5 : false;
    }

    /**
     * Predicts the rating of a user for an item using all admissible neighbours,
     * building the neighbour sets first if they are not available yet.
     *
     * @param userIdx index of the user
     * @param itemIdx index of the item
     * @return the predicted rating
     * @throws LibrecException declared by the overridden method
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        if (null == this.itemNNs || this.itemNNs.length <= 0) {
            this.createItemNNs();
        }
        return this.predictFast(userIdx, itemIdx, -1);
    }

    /**
     * Builds the neighbourhood structures. With {@code knn > 0} every item gets
     * the (at most) {@code knn} entries selected by
     * {@code Lists.sortListTopK(..., true, knn)} from its similarity row; with
     * {@code knn <= 0} a single set holding every item index is created instead.
     */
    public void createItemNNs() {
        // Generic array creation is unavoidable here; only Set<Integer> values are ever stored.
        @SuppressWarnings("unchecked")
        Set<Integer>[] neighborSets = new HashSet[this.numItems];
        this.itemNNs = neighborSets;

        if (knn <= 0) {
            // The neighbourhood is "all items": item indices 0..numItems-1
            // (the user index set must not be used here).
            Set<Integer> everyItem = new HashSet<Integer>(Math.max(16, (int) ((double) this.numItems / 0.75) + 1));
            for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
                everyItem.add(itemIdx);
            }
            this.allItems = everyItem;
            return;
        }

        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            Map<Integer, Double> similarityVector = this.similarityMatrix.row(itemIdx);
            if (knn >= similarityVector.size()) {
                // Not more candidates than requested neighbours: keep them all
                // (yields an empty set for an empty similarity row).
                this.itemNNs[itemIdx] = new HashSet<Integer>(similarityVector.keySet());
                continue;
            }
            List<Map.Entry<Integer, Double>> candidates =
                    new ArrayList<Map.Entry<Integer, Double>>(similarityVector.size() + 1);
            for (Map.Entry<Integer, Double> entry : similarityVector.entrySet()) {
                candidates.add(new AbstractMap.SimpleImmutableEntry<Integer, Double>(entry.getKey(), entry.getValue()));
            }
            candidates = Lists.sortListTopK(candidates, true, knn);
            Set<Integer> topNeighbors = new HashSet<Integer>((int) ((double) candidates.size() / 0.5));
            for (Map.Entry<Integer, Double> entry : candidates) {
                topNeighbors.add(entry.getKey());
            }
            this.itemNNs[itemIdx] = topNeighbors;
        }
    }

    /** Returns the neighbour set to use for an item: its kNN set, or all items when {@code knn <= 0}. */
    private Set<Integer> neighborsOf(int itemIdx) {
        return knn > 0 ? this.itemNNs[itemIdx] : this.allItems;
    }
}

