/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender;

import java.util.HashMap;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.Vector;
import net.librec.math.structure.VectorBasedDenseVector;
import net.librec.math.structure.VectorBasedSequentialSparseVector;
import net.librec.recommender.TensorRecommender;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Base class for factorization-machine (FM) recommenders trained on a tensor of observations.
 *
 * <p>Each tensor entry (one index per tensor dimension) is one-hot encoded into a sparse feature
 * vector of length {@link #p} (see {@link #tenserKeysToFeatureVector(int[])}). It is then scored by
 * the second-order FM model
 *
 * <pre>
 *   y(x) = w0 + sum_i W_i * x_i
 *             + 1/2 * sum_f [ (sum_i V_if * x_i)^2 - sum_i V_if^2 * x_i^2 ]
 * </pre>
 *
 * <p>This class only initializes the model parameters and implements prediction. It is abstract
 * and contains no training loop.
 */
public abstract class FactorizationMachineRecommender
extends TensorRecommender {
    protected final Log LOG = LogFactory.getLog(this.getClass());
    /** Global bias term; reset to 0 in {@link #setup()}. */
    protected double w0;
    /** Total number of features: the sum of the sizes of all training-tensor dimensions. */
    protected int p;
    /** Number of latent factors per feature (config key {@code rec.factor.number}). */
    protected int k;
    /** Value of {@code trainTensor.size()} at setup (presumably the number of training entries — confirm). */
    protected int n;
    /** Linear (per-feature) weights, length {@link #p}. */
    protected VectorBasedDenseVector W;
    /** Latent factor matrix, {@link #p} rows by {@link #k} columns. */
    protected DenseMatrix V;
    /** Not initialized or used in this class; presumably reserved for subclasses — confirm. */
    protected DenseMatrix Q;
    /** Regularization of the global bias (config key {@code rec.fm.regw0}, default 0.01). */
    protected float regW0;
    /** Regularization of the linear weights (config key {@code rec.fm.regW}, default 0.01). */
    protected float regW;
    /** Regularization of the latent factors (config key {@code rec.fm.regF}, default 10.0). */
    protected float regF;
    /** Same value as {@link #k}. */
    protected int numFactors;

    /**
     * Initializes the FM parameters from the training tensor and the configuration.
     *
     * <p>{@link #p} is recomputed from scratch on every call. Because the dimension sizes are summed
     * into a local variable first, calling this method again does not keep growing the feature space.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        int featureCount = 0;
        for (int dim = 0; dim < this.trainTensor.numDimensions; ++dim) {
            featureCount += this.trainTensor.dimensions[dim];
        }
        this.p = featureCount;
        this.n = this.trainTensor.size();
        this.numFactors = this.k = this.conf.getInt("rec.factor.number").intValue();
        this.w0 = 0.0;
        this.W = new VectorBasedDenseVector(this.p);
        this.W.init(0.0);
        this.V = new DenseMatrix(this.p, this.k);
        // init(mean, sigma)-style random initialization, presumably Gaussian — confirm in DenseMatrix.
        this.V.init(0.0, 0.1);
        this.regW0 = this.conf.getFloat("rec.fm.regw0", Float.valueOf(0.01f)).floatValue();
        this.regW = this.conf.getFloat("rec.fm.regW", Float.valueOf(0.01f)).floatValue();
        this.regF = this.conf.getFloat("rec.fm.regF", Float.valueOf(10.0f)).floatValue();
    }

    /**
     * Computes the raw (unclamped) FM score of a sparse feature vector.
     *
     * <p>The pairwise-interaction term uses the identity
     * {@code sum_{i<j} <V_i, V_j> x_i x_j = 1/2 * sum_f [(sum_i V_if x_i)^2 - sum_i V_if^2 x_i^2]}.
     * It therefore costs O(k * nnz(x)) rather than O(k * nnz(x)^2).
     *
     * @param x sparse feature vector whose indices lie in {@code [0, p)}
     * @return the model score for {@code x}
     * @throws LibrecException declared for subclasses; not thrown here
     */
    protected double predict(SequentialSparseVector x) throws LibrecException {
        double res = this.w0;
        // Linear term.
        for (Vector.VectorEntry ve : x) {
            res += ve.get() * this.W.get(ve.index());
        }
        // Pairwise interaction term, one latent factor at a time.
        for (int f = 0; f < this.k; ++f) {
            double sum1 = 0.0;
            double sum2 = 0.0;
            for (Vector.VectorEntry ve : x) {
                double xi = ve.get();
                double vif = this.V.get(ve.index(), f);
                sum1 += vif * xi;
                sum2 += vif * vif * xi * xi;
            }
            res += (sum1 * sum1 - sum2) / 2.0;
        }
        return res;
    }

    /**
     * Computes the FM score and optionally clamps it to {@code [minRate, maxRate]}.
     *
     * @param x     sparse feature vector
     * @param bound if {@code true}, clamp the score to the rating range
     * @return the (possibly clamped) score
     * @throws LibrecException propagated from {@link #predict(SequentialSparseVector)}
     */
    protected double predict(VectorBasedSequentialSparseVector x, boolean bound) throws LibrecException {
        double pred = this.predict(x);
        if (bound) {
            if (pred > this.maxRate) {
                pred = this.maxRate;
            }
            if (pred < this.minRate) {
                pred = this.minRate;
            }
        }
        return pred;
    }

    /**
     * Recovers (user, item) indices from the first two non-zero feature indices of {@code x}.
     *
     * <p>Not called anywhere in this class. It assumes the first non-zero index is the user and the
     * second is an item offset by {@code numUsers}. That holds only if tensor dimension 0 is the user
     * dimension with exactly {@code numUsers} entries (NOTE(review): confirm before use).
     */
    private int[] getUserItemIndex(VectorBasedSequentialSparseVector x) {
        int[] inds = x.getIndices();
        int userInd = inds[0];
        int itemInd = inds[1] - this.numUsers;
        return new int[]{userInd, itemInd};
    }

    /**
     * One-hot encodes a tensor entry as a sparse feature vector of length {@link #p}.
     *
     * <p>Dimension {@code d}'s index is shifted by the total size of dimensions {@code 0..d-1}. Each
     * dimension therefore occupies its own disjoint block of features, and every active feature has
     * value 1.
     *
     * @param tenserKeys one index per tensor dimension
     * @return the encoded feature vector
     */
    protected VectorBasedSequentialSparseVector tenserKeysToFeatureVector(int[] tenserKeys) {
        int capacity = this.p;
        HashMap<Integer, Integer> mapVector = new HashMap<Integer, Integer>(tenserKeys.length * 2);
        int colPrefix = 0;
        for (int i = 0; i < tenserKeys.length; ++i) {
            mapVector.put(colPrefix + tenserKeys[i], 1);
            colPrefix += this.trainTensor.dimensions[i];
        }
        return new VectorBasedSequentialSparseVector(capacity, mapVector);
    }

    /**
     * Scores a tensor entry given by its per-dimension indices (unclamped).
     *
     * @param keys one index per tensor dimension
     * @return the raw FM score
     * @throws LibrecException propagated from {@link #predict(SequentialSparseVector)}
     */
    @Override
    protected double predict(int[] keys) throws LibrecException {
        VectorBasedSequentialSparseVector featureVec = this.tenserKeysToFeatureVector(keys);
        return this.predict(featureVec);
    }
}

