/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.rating;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.Vector;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Rating-prediction matrix factorization whose item factors are additionally regularized by
 * item relationships mined from the training matrix.
 *
 * <p>{@link #setup()} mines two kinds of rules (the "AR" / "GAR" field names presumably stand for
 * association rules and group association rules):
 * <ul>
 *   <li>pairwise rules {@code x -> y}, stored at cell {@code (x, y)}, whose confidence is
 *       {@code shrink * |U(x) ∩ U(y)| / |U(x)|};</li>
 *   <li>group rules {@code (b, c) -> a}, whose confidence is
 *       {@code shrink * |U(a) ∩ U(b) ∩ U(c)| / |U(b) ∩ U(c)|},</li>
 * </ul>
 * where {@code U(i)} is the set of users who rated item {@code i} and
 * {@code shrink = count / (count + C)}. At most {@code K} rules are kept per consequent item.
 * Pairwise rules are only used for items with fewer than {@code K} group rules, filling up to
 * {@code K} in total.
 *
 * <p>Training runs full-batch gradient descent on
 * {@code 0.5 * [ sum (logistic(p_u . q_j) - normalize(r_uj))^2 + L2 terms
 * + alpha * sum_{k->j} conf * ||q_j - q_k||^2
 * + alpha * sum_{(g,k)->j} conf * ||q_j - (q_g + q_k) / sqrt(2)||^2 ]}.
 */
public class IRRGRecommender extends MatrixFactorizationRecommender {
    /** Weight of the item-relationship regularizers; read from {@code rec.alpha}. */
    private double alpha;
    /** Shrinkage constant: confidences are scaled by {@code count / (count + C)}. */
    private double C = 50.0;
    /** Maximum number of rules retained per consequent item. */
    private int K = 50;
    /** Candidate group rules, each as {consequent, antecedent1, antecedent2}. */
    private final List<int[]> itemsetCandidates = new ArrayList<>();
    /** Cell (x, y): number of users who rated both x and y (only positive counts are stored). */
    private final Table<Integer, Integer, Integer> itemCount = HashBasedTable.create();
    /** Cell (x, y): confidence of the pairwise rule x -> y. */
    private final Table<Integer, Integer, Double> itemCorrsAR = HashBasedTable.create();
    /** The top-K pairwise rules for each consequent (column) item. */
    private final Table<Integer, Integer, Double> itemCorrsAR_Sorted = HashBasedTable.create();
    /** Pairwise rules actually used as regularizers during training. */
    private final Table<Integer, Integer, Double> itemCorrsAR_added = HashBasedTable.create();
    /** Group rules per consequent a, flattened as [b, c, conf, b, c, conf, ...]. */
    private final Map<Integer, List<Number>> itemCorrsGAR = new HashMap<>();
    /** Top-K group rules per consequent a: cell (b, c) holds conf((b, c) -> a). */
    private final Map<Integer, Table<Integer, Integer, Double>> itemCorrsGAR_Sorted = new HashMap<>();

    /**
     * Reads {@code rec.alpha}, initializes the factor matrices and mines the item rules.
     *
     * @throws IllegalStateException if {@code rec.alpha} is not configured
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        Double configuredAlpha = this.conf.getDouble("rec.alpha");
        if (configuredAlpha == null) {
            throw new IllegalStateException("rec.alpha must be configured for IRRGRecommender");
        }
        this.alpha = configuredAlpha;
        this.userFactors.init(0.8);
        this.itemFactors.init(0.8);
        this.preprocess();
    }

    /** Runs full-batch gradient descent for up to {@code numIterations} epochs. */
    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.loss = 0.0;
            // Gradients are accumulated over the whole epoch and applied in a single step.
            DenseMatrix userGradients = new DenseMatrix(this.numUsers, this.numFactors);
            DenseMatrix itemGradients = new DenseMatrix(this.numItems, this.numFactors);
            this.accumulateRatingGradients(userGradients, itemGradients);
            this.accumulatePairwiseRuleGradients(itemGradients);
            this.accumulateGroupRuleGradients(itemGradients);
            this.userFactors = this.userFactors.plus(userGradients.times(-this.learnRate));
            this.itemFactors = this.itemFactors.plus(itemGradients.times(-this.learnRate));
            // All loss terms above were summed without the objective's overall 1/2 factor.
            this.loss *= 0.5;
            if (this.isConverged(iter) && this.earlyStop) {
                break;
            }
            this.updateLRate(iter);
        }
    }

    /** Adds the rating-error and L2 gradients (and their loss) for every positive training rating. */
    private void accumulateRatingGradients(DenseMatrix userGradients, DenseMatrix itemGradients)
            throws LibrecException {
        for (MatrixEntry me : this.trainMatrix) {
            int u = me.row();
            int j = me.column();
            double ruj = me.get();
            if (ruj <= 0.0) {
                continue;
            }
            double pred = this.innerPredict(u, j);
            double euj = Maths.logistic(pred) - Maths.normalize(ruj, this.minRate, this.maxRate);
            double csgd = Maths.logisticGradientValue(pred) * euj;
            this.loss += euj * euj;
            for (int f = 0; f < this.numFactors; ++f) {
                double puf = this.userFactors.get(u, f);
                double qjf = this.itemFactors.get(j, f);
                userGradients.plus(u, f, csgd * qjf + (double) this.regUser * puf);
                itemGradients.plus(j, f, csgd * puf + (double) this.regItem * qjf);
                this.loss += (double) this.regUser * puf * puf + (double) this.regItem * qjf * qjf;
            }
        }
    }

    /**
     * Adds the gradient of {@code alpha * conf * ||q_j - q_k||^2} for every pairwise rule k -> j.
     * Both endpoints are updated, so antecedent-only items also receive their share.
     */
    private void accumulatePairwiseRuleGradients(DenseMatrix itemGradients) {
        for (Table.Cell<Integer, Integer, Double> rule : this.itemCorrsAR_added.cellSet()) {
            int k = rule.getRowKey();
            int j = rule.getColumnKey();
            double weight = this.alpha * rule.getValue();
            for (int f = 0; f < this.numFactors; ++f) {
                double ekj = this.itemFactors.get(j, f) - this.itemFactors.get(k, f);
                itemGradients.plus(j, f, weight * ekj);
                itemGradients.plus(k, f, -weight * ekj);
                this.loss += weight * ekj * ekj;
            }
        }
    }

    /**
     * Adds the gradient of {@code alpha * conf * ||q_j - (q_g + q_k) / sqrt(2)||^2} for every group
     * rule (g, k) -> j. The consequent and both antecedents are updated; the two antecedents
     * enter the penalty symmetrically.
     */
    private void accumulateGroupRuleGradients(DenseMatrix itemGradients) {
        double sqrt2 = Math.sqrt(2.0);
        for (Map.Entry<Integer, Table<Integer, Integer, Double>> group : this.itemCorrsGAR_Sorted.entrySet()) {
            int j = group.getKey();
            for (Table.Cell<Integer, Integer, Double> rule : group.getValue().cellSet()) {
                int g = rule.getRowKey();
                int k = rule.getColumnKey();
                double weight = this.alpha * rule.getValue();
                for (int f = 0; f < this.numFactors; ++f) {
                    double egkj = this.itemFactors.get(j, f)
                            - (this.itemFactors.get(g, f) + this.itemFactors.get(k, f)) / sqrt2;
                    double consequentGradient = weight * egkj;
                    double antecedentGradient = -consequentGradient / sqrt2;
                    itemGradients.plus(j, f, consequentGradient);
                    itemGradients.plus(g, f, antecedentGradient);
                    itemGradients.plus(k, f, antecedentGradient);
                    this.loss += consequentGradient * egkj;
                }
            }
        }
    }

    /** Returns the raw (pre-logistic) score {@code p_u . q_i}. */
    protected double innerPredict(int userIdx, int itemIdx) throws LibrecException {
        return this.userFactors.row(userIdx).dot(this.itemFactors.row(itemIdx));
    }

    /** Maps the logistic of the raw score back onto the [minRate, maxRate] rating scale. */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        double pred = Maths.logistic(this.innerPredict(userIdx, itemIdx));
        return this.minRate + pred * (this.maxRate - this.minRate);
    }

    /**
     * Computes the confidence of every pairwise rule x -> y (x != y) with at least one common
     * rater and records the co-rating counts in {@code itemCount}.
     */
    protected void computeAR() {
        // Hoisted out of the O(items^2) loop: per-item rater indices and their multiplicities.
        int[][] ratersOfItem = new int[this.numItems][];
        int[] ratingCounts = new int[this.numItems];
        List<Map<Integer, Integer>> raterOccurrences = new ArrayList<>(this.numItems);
        for (int item = 0; item < this.numItems; ++item) {
            SequentialSparseVector column = this.trainMatrix.column(item);
            ratersOfItem[item] = column.getIndices();
            ratingCounts[item] = column.getNumEntries();
            Map<Integer, Integer> occurrences = new HashMap<>();
            for (int u : ratersOfItem[item]) {
                Integer seen = occurrences.get(u);
                occurrences.put(u, seen == null ? 1 : seen + 1);
            }
            raterOccurrences.add(occurrences);
        }
        for (int x = 0; x < this.numItems; ++x) {
            int total = ratingCounts[x];
            for (int y = 0; y < this.numItems; ++y) {
                if (x == y) {
                    continue;
                }
                Map<Integer, Integer> yOccurrences = raterOccurrences.get(y);
                int count = 0;
                for (int u : ratersOfItem[x]) {
                    Integer matches = yOccurrences.get(u);
                    if (matches != null) {
                        count += matches;
                    }
                }
                if (count == 0) {
                    continue; // zero shrink -> zero (or NaN) confidence, never stored
                }
                double shrink = (double) count / ((double) count + this.C);
                double conf = shrink * count / total;
                if (!(conf > 0.0)) {
                    continue;
                }
                this.itemCorrsAR.put(x, y, conf);
                this.itemCount.put(x, y, count);
            }
        }
    }

    /** Stores the first {@code size} rows of {@code temp} ({antecedent, consequent, conf}). */
    protected void storeAR(int size, double[][] temp) {
        for (int i = 0; i < size; ++i) {
            this.itemCorrsAR_Sorted.put((int) temp[i][0], (int) temp[i][1], temp[i][2]);
        }
    }

    /** Keeps, for every consequent item, its (at most) K most confident pairwise rules. */
    protected void sortAR() {
        for (int consequent : this.itemCorrsAR.columnKeySet()) {
            Map<Integer, Double> antecedents = this.itemCorrsAR.column(consequent);
            double[][] rules = new double[antecedents.size()][];
            int n = 0;
            for (Map.Entry<Integer, Double> rule : antecedents.entrySet()) {
                rules[n++] = new double[] {rule.getKey(), consequent, rule.getValue()};
            }
            if (rules.length > this.K) {
                selectTopDescending(rules, this.K, 2);
                this.storeAR(this.K, rules);
            } else {
                this.storeAR(rules.length, rules);
            }
        }
    }

    /**
     * Builds candidate group rules (b, c) -> y from pairs of y's retained antecedents that were
     * themselves co-rated by at least one user.
     */
    protected void computeItemset() {
        for (int consequent : this.itemCorrsAR.columnKeySet()) {
            List<Integer> antecedents = new ArrayList<>(this.itemCorrsAR_Sorted.column(consequent).keySet());
            for (int i = 0; i < antecedents.size() - 1; ++i) {
                for (int k = i + 1; k < antecedents.size(); ++k) {
                    int first = antecedents.get(i);
                    int second = antecedents.get(k);
                    if (this.itemCount.contains(first, second)) {
                        this.itemsetCandidates.add(new int[] {consequent, first, second});
                    }
                }
            }
        }
    }

    /**
     * Records the group rule (b, c) -> a, where {@code count} users rated all three items.
     * Assumes {@code itemCount} holds the (b, c) co-rating count (guaranteed for candidates
     * produced by {@link #computeItemset()}).
     */
    protected void compute(int a, int b, int c, int count) {
        double shrink = (double) count / ((double) count + this.C);
        int coRatedBC = this.itemCount.get(b, c);
        double conf = shrink * count / coRatedBC;
        List<Number> rules = this.itemCorrsGAR.get(a);
        if (rules == null) {
            rules = new ArrayList<>();
            this.itemCorrsGAR.put(a, rules);
        }
        rules.add(b);
        rules.add(c);
        rules.add(conf);
    }

    /** Counts, for each candidate group rule, the users who rated all three items. */
    protected void computeGAR() {
        for (int[] itemset : this.itemsetCandidates) {
            int a = itemset[0];
            int b = itemset[1];
            int c = itemset[2];
            SequentialSparseVector ratersOfA = this.trainMatrix.column(a);
            int count = 0;
            for (Vector.VectorEntry ve : ratersOfA) {
                if (containsBoth(this.trainMatrix.row(ve.index()).getIndices(), b, c)) {
                    ++count;
                }
            }
            if (count > 0) {
                this.compute(a, b, c, count);
            }
        }
    }

    /** Stores the first {@code size / 3} flattened [b, c, conf] triples as group rules of {@code a}. */
    protected void storeGAR(int a, int size, List<? extends Number> list) {
        if (size <= 0) {
            return;
        }
        Table<Integer, Integer, Double> rules = this.itemCorrsGAR_Sorted.get(a);
        if (rules == null) {
            rules = HashBasedTable.create();
            this.itemCorrsGAR_Sorted.put(a, rules);
        }
        for (int i = 0; i < size; i += 3) {
            rules.put(list.get(i).intValue(), list.get(i + 1).intValue(), list.get(i + 2).doubleValue());
        }
    }

    /** Keeps, for every consequent item, its (at most) K most confident group rules. */
    protected void sortGAR() {
        for (Map.Entry<Integer, List<Number>> entry : this.itemCorrsGAR.entrySet()) {
            int a = entry.getKey();
            List<Number> flat = entry.getValue();
            int numRules = flat.size() / 3;
            if (numRules <= this.K) {
                this.storeGAR(a, flat.size(), flat);
                continue;
            }
            double[][] rules = new double[numRules][];
            for (int r = 0; r < numRules; ++r) {
                rules[r] = new double[] {
                    flat.get(3 * r).doubleValue(),
                    flat.get(3 * r + 1).doubleValue(),
                    flat.get(3 * r + 2).doubleValue()
                };
            }
            selectTopDescending(rules, this.K, 2);
            List<Number> top = new ArrayList<>(3 * this.K);
            for (int r = 0; r < this.K; ++r) {
                top.add((int) rules[r][0]);
                top.add((int) rules[r][1]);
                top.add(rules[r][2]);
            }
            this.storeGAR(a, top.size(), top);
        }
    }

    /** Copies all retained pairwise rules with consequent {@code j} into the training set. */
    protected void storeCAR(int j) {
        for (Map.Entry<Integer, Double> rule : this.itemCorrsAR_Sorted.column(j).entrySet()) {
            this.itemCorrsAR_added.put(rule.getKey(), j, rule.getValue());
        }
    }

    /**
     * Selects the pairwise rules used during training: items with no group rules get all their
     * retained pairwise rules, items with fewer than K group rules are topped up to K with their
     * most confident pairwise rules, and items with K or more group rules get none.
     */
    protected void addAR() {
        for (int j = 0; j < this.numItems; ++j) {
            Table<Integer, Integer, Double> groupRules = this.itemCorrsGAR_Sorted.get(j);
            if (groupRules == null) {
                this.storeCAR(j);
                continue;
            }
            int groupSize = groupRules.size();
            if (groupSize >= this.K) {
                continue;
            }
            int addSize = this.K - groupSize;
            Map<Integer, Double> pairwiseRules = this.itemCorrsAR_Sorted.column(j);
            if (pairwiseRules.size() <= addSize) {
                this.storeCAR(j);
                continue;
            }
            double[][] candidates = new double[pairwiseRules.size()][];
            int n = 0;
            for (Map.Entry<Integer, Double> rule : pairwiseRules.entrySet()) {
                candidates[n++] = new double[] {rule.getKey(), rule.getValue()};
            }
            selectTopDescending(candidates, addSize, 1);
            for (int r = 0; r < addSize; ++r) {
                this.itemCorrsAR_added.put((int) candidates[r][0], j, candidates[r][1]);
            }
        }
    }

    /** Mines all item rules from the training matrix, starting from empty state. */
    protected void preprocess() {
        // The mining steps append to these collections, so clear them to keep repeated
        // setup() calls from duplicating rules.
        this.itemsetCandidates.clear();
        this.itemCount.clear();
        this.itemCorrsAR.clear();
        this.itemCorrsAR_Sorted.clear();
        this.itemCorrsAR_added.clear();
        this.itemCorrsGAR.clear();
        this.itemCorrsGAR_Sorted.clear();
        this.computeAR();
        this.sortAR();
        this.computeItemset();
        this.computeGAR();
        this.sortGAR();
        this.addAR();
    }

    /**
     * Moves the {@code limit} highest-scoring rows (by column {@code scoreCol}) to the front of
     * {@code rows} in descending order, via a swap-based partial selection sort. It is not
     * stable: ties are resolved by the swap pattern. Rows past {@code limit} end in no
     * particular order.
     */
    private static void selectTopDescending(double[][] rows, int limit, int scoreCol) {
        int bound = Math.min(limit, rows.length);
        for (int i = 0; i < bound; ++i) {
            for (int j = i + 1; j < rows.length; ++j) {
                if (rows[i][scoreCol] < rows[j][scoreCol]) {
                    double[] swap = rows[i];
                    rows[i] = rows[j];
                    rows[j] = swap;
                }
            }
        }
    }

    /** Returns whether {@code items} contains both {@code first} and {@code second}. */
    private static boolean containsBoth(int[] items, int first, int second) {
        boolean foundFirst = false;
        boolean foundSecond = false;
        for (int item : items) {
            if (item == first) {
                foundFirst = true;
            }
            if (item == second) {
                foundSecond = true;
            }
            if (foundFirst && foundSecond) {
                return true;
            }
        }
        return false;
    }
}

