/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.cf.ranking;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.SequentialAccessSparseMatrix;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.Vector;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.recommender.MatrixFactorizationRecommender;

@ModelData(value={"isRanking", "rankals", "userFactors", "itemFactors", "trainMatrix"})
public class RankALSRecommender
extends MatrixFactorizationRecommender {
    private boolean isSupportWeight;
    private VectorBasedDenseVector supportVector;
    private double sumSupport;

    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.isSupportWeight = this.conf.getBoolean("rec.rankals.support.weight", true);
        this.supportVector = new VectorBasedDenseVector(this.numItems);
        this.sumSupport = 0.0;
        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            double supportValue = this.isSupportWeight ? (double)this.trainMatrix.column(itemIdx).getNumEntries() : 1.0;
            this.supportVector.set(itemIdx, supportValue);
            this.sumSupport += supportValue;
        }
    }

    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter < this.numIterations; ++iter) {
            Object sum_sqr;
            System.out.println("Train iteration " + iter);
            DenseVector sum_sq = new VectorBasedDenseVector(this.numFactors);
            DenseMatrix sum_sqq = new DenseMatrix(this.numFactors, this.numFactors);
            for (int j = 0; j < this.numItems; ++j) {
                DenseVector qj = this.itemFactors.row(j);
                double sj = this.supportVector.get(j);
                sum_sq = sum_sq.plus(qj.times(sj));
                sum_sqq = sum_sqq.plus(qj.outer(qj).times(sj));
            }
            List<Integer> cus = this.nonEmptyRows(this.trainMatrix);
            double user_loss = 0.0;
            for (int u : cus) {
                DenseMatrix sum_cqq = new DenseMatrix(this.numFactors, this.numFactors);
                DenseVector sum_cq = new VectorBasedDenseVector(this.numFactors);
                DenseVector sum_cqr = new VectorBasedDenseVector(this.numFactors);
                sum_sqr = new VectorBasedDenseVector(this.numFactors);
                SequentialSparseVector Ru = this.trainMatrix.row(u);
                double sum_c = Ru.getNumEntries();
                double sum_sr = 0.0;
                double sum_cr = 0.0;
                for (Vector.VectorEntry ve : Ru) {
                    int i = ve.index();
                    double rui = ve.get();
                    DenseVector qi = this.itemFactors.row(i);
                    sum_cqq = sum_cqq.plus(qi.outer(qi));
                    sum_cq = sum_cq.plus(qi);
                    sum_cqr = sum_cqr.plus(qi.times(rui));
                    double si = this.supportVector.get(i);
                    sum_sr += si * rui;
                    sum_cr += rui;
                    sum_sqr = ((DenseVector)sum_sqr).plus(qi.times(si * rui));
                }
                DenseMatrix M = sum_cqq.times(this.sumSupport).minus(sum_cq.outer(sum_sq)).minus(sum_sq.outer(sum_cq)).plus(sum_sqq.times(sum_c));
                DenseVector y = sum_cqr.times(this.sumSupport).minus(sum_cq.times(sum_sr)).minus(sum_sq.times(sum_cr)).plus(((DenseVector)sum_sqr).times(sum_c));
                DenseVector pu = M.inverse().times(y);
                user_loss += y.getLengthSquared();
                this.userFactors.row(u).assign((index, value) -> pu.get(index));
            }
            String info = "RankALS iter " + iter + ": sq. user loss = " + user_loss;
            this.LOG.info(info);
            HashMap<Integer, Double> m_sum_sr = new HashMap<Integer, Double>();
            HashMap<Integer, Double> m_sum_cr = new HashMap<Integer, Double>();
            HashMap<Integer, Double> m_sum_c = new HashMap<Integer, Double>();
            HashMap<Integer, VectorBasedDenseVector> m_sum_cq = new HashMap<Integer, VectorBasedDenseVector>();
            sum_sqr = cus.iterator();
            while (sum_sqr.hasNext()) {
                int u = (Integer)sum_sqr.next();
                SequentialSparseVector Ru = this.trainMatrix.row(u);
                double sum_sr = 0.0;
                double sum_cr = 0.0;
                double sum_c = Ru.getNumEntries();
                DenseVector sum_cq = new VectorBasedDenseVector(this.numFactors);
                for (Object ve : Ru) {
                    int j = ve.index();
                    double ruj = ve.get();
                    double sj = this.supportVector.get(j);
                    sum_sr += sj * ruj;
                    sum_cr += ruj;
                    sum_cq = sum_cq.plus(this.itemFactors.row(j));
                }
                m_sum_sr.put(u, sum_sr);
                m_sum_cr.put(u, sum_cr);
                m_sum_c.put(u, sum_c);
                m_sum_cq.put(u, (VectorBasedDenseVector)sum_cq);
            }
            for (int i = 0; i < this.numItems; ++i) {
                Object ve;
                DenseMatrix sum_cpp = new DenseMatrix(this.numFactors, this.numFactors);
                DenseMatrix sum_p_p_c = new DenseMatrix(this.numFactors, this.numFactors);
                DenseVector sum_p_p_cq = new VectorBasedDenseVector(this.numFactors);
                DenseVector sum_cpr = new VectorBasedDenseVector(this.numFactors);
                DenseVector sum_c_sr_p = new VectorBasedDenseVector(this.numFactors);
                DenseVector sum_cr_p = new VectorBasedDenseVector(this.numFactors);
                DenseVector sum_p_r_c = new VectorBasedDenseVector(this.numFactors);
                double si = this.supportVector.get(i);
                VectorBasedDenseVector itemVector = new VectorBasedDenseVector(this.trainMatrix.column(i));
                ve = cus.iterator();
                while (ve.hasNext()) {
                    int u = (Integer)ve.next();
                    DenseVector pu = this.userFactors.row(u);
                    double rui = itemVector.get(u);
                    DenseMatrix pp = pu.outer(pu);
                    sum_cpp = sum_cpp.plus(pp);
                    sum_p_p_cq = sum_p_p_cq.plus(pp.times((Vector)m_sum_cq.get(u)));
                    sum_p_p_c = sum_p_p_c.plus(pp.times((Double)m_sum_c.get(u)));
                    sum_cr_p = sum_cr_p.plus(pu.times((Double)m_sum_cr.get(u)));
                    if (!(rui > 0.0)) continue;
                    sum_cpr = sum_cpr.plus(pu.times(rui));
                    sum_c_sr_p = sum_c_sr_p.plus(pu.times((Double)m_sum_sr.get(u)));
                    sum_p_r_c = sum_p_r_c.plus(pu.times(rui * (Double)m_sum_c.get(u)));
                }
                DenseMatrix subtract = sum_cpp.times(si + 1.0);
                DenseMatrix M = sum_cpp.times(this.sumSupport).plus(sum_p_p_c.times(si)).minus(subtract);
                DenseVector y = sum_cpp.times(sum_sq).plus(sum_cpr.times(this.sumSupport)).minus(sum_c_sr_p).plus(sum_p_p_cq.times(si)).minus(sum_cr_p.times(si)).plus(sum_p_r_c.times(si));
                DenseVector qi = M.inverse().times(y.minus(subtract.times(this.itemFactors.row(i))));
                this.itemFactors.row(i).assign((index, value) -> qi.get(index));
            }
        }
    }

    public List<Integer> nonEmptyRows(SequentialAccessSparseMatrix matrix) {
        if (matrix == null) {
            this.LOG.error("The matrix passed in is null.");
            return null;
        }
        ArrayList<Integer> list = new ArrayList<Integer>(matrix.rowSize());
        for (int userId = 0; userId < matrix.rowSize(); ++userId) {
            if (matrix.row(userId).getNumEntries() <= 0) continue;
            list.add(userId);
        }
        return list;
    }
}

