/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.context.ranking;

import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Set;
import java.util.stream.Collectors;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.TensorEntry;
import net.librec.math.structure.Vector;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.math.structure.VectorBasedSequentialSparseVector;
import net.librec.recommender.FactorizationMachineRecommender;

/**
 * Pairwise-ranking factorization machine trained by SGD on (user, observed item, unobserved item)
 * triples.
 *
 * <p>The unobserved ("negative") item is picked by rank-aware sampling:
 * <ul>
 *   <li>a small set of unobserved candidates is scored with the current model;</li>
 *   <li>the candidates are ranked by score;</li>
 *   <li>one is drawn with probability proportional to {@code exp(-rank / (candidates * rho))}.</li>
 * </ul>
 * Higher-scored negatives are therefore favoured.
 *
 * <p>Tensor keys are read as {@code [user, item, itemFeature]}. Each item is mapped to the feature
 * index of the first tensor entry seen for it, scanning the training tensor before the test tensor.
 *
 * <p>NOTE(review): the class name suggests the dynamic-sampling variant of LambdaFM. Confirm
 * against the reference paper before relying on that description.
 */
@ModelData(value={"isRanking", "lambdafm", "userFactors", "itemFactors"})
public class DLambdaFMRecommender extends FactorizationMachineRecommender {
    /** {@code Integer.MAX_VALUE} as a double. Not referenced in this class; kept because it is public. */
    public static double max = 2.147483647E9;

    /** SGD pair updates per user per iteration; one iteration performs {@code numUsers * 300} updates. */
    private static final int SAMPLES_PER_USER = 300;

    /** Number of unobserved candidate items scored for each rank-aware negative draw. */
    private static final int NEGATIVE_CANDIDATES = 10;

    /**
     * Sharpness of the rank-aware sampler ({@code rec.recommender.rho}, default 0.1).
     * Smaller values put more probability on the top-ranked candidates.
     */
    private double rho;

    /** Selector for {@link #getGradMag(int, double)} ({@code rec.recommender.lossf}, default 1). */
    private int lossf;

    /** SGD step size ({@code rec.iterator.learnRate}, default 0.1). */
    private double lRate;

    /** Item index to that item's feature index (third tensor key). */
    private HashMap<Integer, Integer> itemFeatureMapping;

    /**
     * Reads the sampler, loss and learning-rate settings, then builds the item-to-feature mapping
     * from the training tensor followed by the test tensor.
     *
     * <p>The test tensor is scanned too, so items that appear only at test time can still be
     * scored. The first feature seen for an item wins.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.rho = this.conf.getDouble("rec.recommender.rho", 0.1);
        this.lossf = this.conf.getInt("rec.recommender.lossf", 1);
        this.lRate = this.conf.getDouble("rec.iterator.learnRate", 0.1);
        this.itemFeatureMapping = new HashMap<>();
        this.recordItemFeatures(this.trainTensor);
        this.recordItemFeatures(this.testTensor);
    }

    /** Adds {@code item -> feature} (keys 1 and 2) for every entry whose item is not mapped yet. */
    private void recordItemFeatures(Iterable<? extends TensorEntry> tensor) {
        for (TensorEntry te : tensor) {
            int[] keys = te.keys();
            this.itemFeatureMapping.putIfAbsent(keys[1], keys[2]);
        }
    }

    /**
     * Runs pairwise SGD until {@code numIterations} is reached or {@link #isConverged(int)} stops it.
     *
     * <p>Each update does the following:
     * <ol>
     *   <li>draws a user and one of that user's observed items;</li>
     *   <li>picks an unobserved item via rank-aware sampling;</li>
     *   <li>moves the linear weights {@code W} and factors {@code V} of all active features along
     *       the pairwise gradient, scaled by {@link #getGradMag(int, double)}.</li>
     * </ol>
     *
     * @throws LibrecException if a prediction or the convergence check fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        // Scratch gradient indexed by feature id. Only the entries active in the current pair are
        // read, and they are cleared before every use.
        VectorBasedDenseVector grad = new VectorBasedDenseVector(this.p);
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            int numSamples = this.numUsers * SAMPLES_PER_USER;
            for (int s = 0; s < numSamples; ++s) {
                // Need at least one observed item. Also need at least one unobserved item:
                // a user who owns every item would make negative sampling loop forever.
                int u;
                SequentialSparseVector pu;
                do {
                    u = Randoms.uniform(this.numUsers);
                    pu = this.trainMatrix.row(u);
                } while (pu.getNumEntries() == 0 || pu.getNumEntries() >= this.numItems);

                int[] observed = pu.getIndices();
                Set<Integer> userItems = Arrays.stream(observed).boxed().collect(Collectors.toSet());
                int i = observed[Randoms.uniform(observed.length)];
                int j = this.sampleNegative(NEGATIVE_CANDIDATES, u, userItems);

                VectorBasedSequentialSparseVector xi =
                        this.tenserKeysToFeatureVector(new int[]{u, i, this.featureOf(i)});
                VectorBasedSequentialSparseVector xj =
                        this.tenserKeysToFeatureVector(new int[]{u, j, this.featureOf(j)});

                // The predictions and factor sums all use the parameters from before this update.
                double xuij = this.predict(xi) - this.predict(xj);
                DenseVector sumPos = this.sumFactors(xi);
                DenseVector sumNeg = this.sumFactors(xj);
                double cmg = this.getGradMag(this.lossf, xuij);

                // Logistic pairwise loss. The target P(i > j) is 1 because i is observed and j is
                // not, so the loss is -log(sigmoid(xuij)). Computing it as softplus(-xuij) avoids
                // log(0) and 0 * log(0) = NaN once the logistic saturates.
                this.loss += softplus(-xuij);

                int[] active = activeFeatures(xi, xj);
                this.updateLinearWeights(grad, active, xi, xj, cmg);
                for (int f = 0; f < this.numFactors; ++f) {
                    this.updateFactor(grad, active, xi, xj, f, sumPos.get(f), sumNeg.get(f), cmg);
                }
            }
            if (this.isConverged(iter)) {
                break;
            }
            this.lastLoss = this.loss;
        }
    }

    /** Returns the distinct feature indices active in either vector, those of {@code a} first. */
    private static int[] activeFeatures(SequentialSparseVector a, SequentialSparseVector b) {
        int[] ia = a.getIndices();
        int[] ib = b.getIndices();
        int[] all = Arrays.copyOf(ia, ia.length + ib.length);
        System.arraycopy(ib, 0, all, ia.length, ib.length);
        return Arrays.stream(all).distinct().toArray();
    }

    /**
     * Applies one SGD step to the linear weight {@code W[idx]} of every active feature.
     * The pairwise gradient there is {@code xi[idx] - xj[idx]}.
     */
    private void updateLinearWeights(VectorBasedDenseVector grad, int[] active,
            SequentialSparseVector xi, SequentialSparseVector xj, double cmg) {
        for (int idx : active) {
            grad.set(idx, 0.0);
        }
        for (Vector.VectorEntry ve : xi) {
            grad.plus(ve.index(), ve.get());
        }
        for (Vector.VectorEntry ve : xj) {
            grad.plus(ve.index(), -ve.get());
        }
        for (int idx : active) {
            this.W.plus(idx, this.lRate * (cmg * grad.get(idx) - (double) this.regW * this.W.get(idx)));
        }
    }

    /**
     * Applies one SGD step to factor column {@code f} of {@code V} for every active feature.
     *
     * <p>For each sample the FM gradient term is
     * {@code x[idx] * sumF - V[idx][f] * x[idx]^2}, where {@code sumF} is that sample's pre-update
     * factor sum. The negative sample's term is subtracted.
     */
    private void updateFactor(VectorBasedDenseVector grad, int[] active, SequentialSparseVector xi,
            SequentialSparseVector xj, int f, double sumPosF, double sumNegF, double cmg) {
        for (int idx : active) {
            grad.set(idx, 0.0);
        }
        for (Vector.VectorEntry ve : xi) {
            int idx = ve.index();
            double value = ve.get();
            grad.plus(idx, sumPosF * value - this.V.get(idx, f) * value * value);
        }
        for (Vector.VectorEntry ve : xj) {
            int idx = ve.index();
            double value = ve.get();
            grad.plus(idx, -(sumNegF * value - this.V.get(idx, f) * value * value));
        }
        for (int idx : active) {
            this.V.plus(idx, f, this.lRate * (cmg * grad.get(idx) - (double) this.regF * this.V.get(idx, f)));
        }
    }

    /**
     * Computes, for each factor {@code f}, the sum over the active entries of {@code x} of
     * {@code V[idx][f] * x[idx]}.
     *
     * @param x sparse feature vector
     * @return dense vector of length {@code numFactors}
     */
    private DenseVector sumFactors(SequentialSparseVector x) {
        double[] acc = new double[this.numFactors];
        for (Vector.VectorEntry ve : x) {
            int idx = ve.index();
            double value = ve.get();
            for (int f = 0; f < this.numFactors; ++f) {
                acc[f] += this.V.get(idx, f) * value;
            }
        }
        VectorBasedDenseVector sums = new VectorBasedDenseVector(this.numFactors);
        for (int f = 0; f < this.numFactors; ++f) {
            sums.set(f, acc[f]);
        }
        return sums;
    }

    /**
     * Draws an item the user has not interacted with, preferring items the model currently
     * scores highly.
     *
     * <p>The method draws {@code numCandidates} unobserved items uniformly, with replacement, and
     * scores them. It then ranks them by descending score, ties keeping draw order. Finally it
     * picks the candidate at rank {@code r} (0-based) with probability proportional to
     * {@code exp(-(r + 1) / (numCandidates * rho))}.
     *
     * @param numCandidates number of candidates to score; must be in {@code [1, numItems]}
     * @param u user index
     * @param userItems items observed for {@code u}; must not contain every item
     * @return an item index not contained in {@code userItems}
     * @throws IllegalArgumentException if {@code numCandidates} is out of range
     * @throws LibrecException if scoring a candidate fails
     */
    private int sampleNegative(int numCandidates, int u, Set<Integer> userItems) throws LibrecException {
        if (numCandidates <= 0 || numCandidates > this.numItems) {
            throw new IllegalArgumentException(
                    "numCandidates must be in [1, " + this.numItems + "], got " + numCandidates);
        }
        int[] candidates = new int[numCandidates];
        double[] scores = new double[numCandidates];
        for (int c = 0; c < numCandidates; ++c) {
            int j;
            do {
                j = Randoms.uniform(this.numItems);
            } while (userItems.contains(j));
            candidates[c] = j;
            scores[c] = this.predict(new int[]{u, j, this.featureOf(j)});
        }

        // Rank only the scored candidates. Arrays.sort on objects is stable, so ties keep draw order.
        Integer[] order = new Integer[numCandidates];
        for (int c = 0; c < numCandidates; ++c) {
            order[c] = c;
        }
        Arrays.sort(order, Comparator.comparingDouble((Integer c) -> scores[c]).reversed());

        double[] probs = new double[numCandidates];
        double total = 0.0;
        for (int r = 0; r < numCandidates; ++r) {
            probs[r] = Math.exp(-(r + 1) / (numCandidates * this.rho));
            total += probs[r];
        }
        for (int r = 0; r < numCandidates; ++r) {
            probs[r] /= total;
        }

        int pick;
        try {
            pick = Randoms.discrete(probs);
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            // Randoms.discrete may declare a checked exception. The weights are normalized, so a
            // failure here indicates a bug rather than a recoverable condition.
            throw new IllegalStateException("rank-aware negative sampling failed for user " + u, e);
        }
        return candidates[order[pick]];
    }

    /**
     * Numerically stable {@code log(1 + e^x)}.
     * It does not overflow for large {@code |x|} and returns finite values for finite {@code x}.
     */
    private static double softplus(double x) {
        return x > 0.0 ? x + Math.log1p(Math.exp(-x)) : Math.log1p(Math.exp(x));
    }

    /**
     * Returns the feature index recorded for {@code itemIdx} during setup.
     *
     * @throws IllegalStateException if the item appeared in neither the training nor the test tensor
     */
    private int featureOf(int itemIdx) {
        Integer feature = this.itemFeatureMapping.get(itemIdx);
        if (feature == null) {
            throw new IllegalStateException("no item feature recorded for item " + itemIdx);
        }
        return feature;
    }

    /**
     * Returns the scalar multiplier applied to the pairwise gradient for loss type {@code losstype},
     * given the score difference {@code xuij}.
     *
     * <ul>
     *   <li>0: {@code 1} if {@code xuij <= 1}, else {@code 0}</li>
     *   <li>1: {@code 1} if {@code xuij <= 0}; {@code 1 - xuij} if {@code 0 < xuij <= 1};
     *       else {@code 0}</li>
     *   <li>2: {@code logistic(-xuij)}</li>
     *   <li>3: {@code sqrt(logistic(xuij)) / (1 + e^xuij)}</li>
     *   <li>4: {@code e^-xuij}</li>
     *   <li>5: {@code 0.5 * (1 - xuij)} if {@code xuij <= 1}, else {@code 0}</li>
     * </ul>
     *
     * <p>Any other value returns {@code 0}, so only the regularization term of the update applies.
     * A NaN {@code xuij} also returns 0 for types 0, 1 and 5.
     *
     * @param losstype loss selector
     * @param xuij score of the observed item minus score of the unobserved item
     * @return gradient magnitude
     */
    protected double getGradMag(int losstype, double xuij) {
        switch (losstype) {
            case 0:
                return xuij <= 1.0 ? 1.0 : 0.0;
            case 1:
                if (xuij <= 0.0) {
                    return 1.0;
                }
                return xuij <= 1.0 ? 1.0 - xuij : 0.0;
            case 2:
                return Maths.logistic(-xuij);
            case 3:
                return Math.sqrt(Maths.logistic(xuij)) / (1.0 + Math.exp(xuij));
            case 4:
                return Math.exp(-xuij);
            case 5:
                return xuij <= 1.0 ? 0.5 * (1.0 - xuij) : 0.0;
            default:
                return 0.0;
        }
    }

    /**
     * Scores {@code (userIdx, itemIdx)} by building the tensor key
     * {@code [userIdx, itemIdx, featureOf(itemIdx)]}.
     *
     * @throws IllegalStateException if no feature is recorded for {@code itemIdx}
     * @throws LibrecException if the underlying prediction fails
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        return this.predict(new int[]{userIdx, itemIdx, this.featureOf(itemIdx)});
    }
}

