/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.poi;

import com.google.common.collect.HashBasedTable;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import net.librec.common.LibrecException;
import net.librec.data.convertor.appender.LocationDataAppender;
import net.librec.math.algorithm.Maths;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SequentialAccessSparseMatrix;
import net.librec.math.structure.Vector;
import net.librec.recommender.MatrixFactorizationRecommender;
import net.librec.recommender.item.KeyValue;
import net.librec.util.Lists;
import org.apache.commons.lang.ArrayUtils;

/**
 * Ranking-based geographical factorization recommender for point-of-interest (POI) data.
 * NOTE(review): presumably an implementation of Rank-GeoFM (Li et al., SIGIR 2015) — confirm.
 *
 * <p>The score of POI {@code l} for user {@code u} is
 * {@code userFactors[u] . poiFactors[l] + geoUserFactors[u] . geoInfluenceMatrix[l]}, where row
 * {@code l} of {@code geoInfluenceMatrix} is a weighted sum of the POI factors of the {@code knn}
 * neighbors of {@code l}, using the row-normalized distance weights in {@code poiKNNWeightMatrix}.
 *
 * <p>Training is a ranking-based SGD. For every observed (user, POI) entry, random POIs are drawn
 * until one "violates" the ranking: it has a lower true rating than the positive POI, yet its
 * predicted score plus {@code epsilon} exceeds the positive's score. The number of draws needed
 * sets the update weight.
 */
public class RankGeoFMRecommender
extends MatrixFactorizationRecommender {
    /** Per-user latent factors (numUsers x numFactors). */
    protected DenseMatrix userFactors;
    /** Per-user factors that multiply the geographical influence rows (numUsers x numFactors). */
    protected DenseMatrix geoUserFactors;
    /** Per-POI latent factors (numPois x numFactors). */
    protected DenseMatrix poiFactors;
    /** Number of POIs; equal to {@code numItems}. */
    protected int numPois;
    /** Row-normalized neighbor weights between POIs (numPois x numPois, at most knn entries per row). */
    protected SequentialAccessSparseMatrix poiKNNWeightMatrix;
    /** Margin used in the ranking-violation test ("rec.ranking.epsilon"). */
    double epsilon;
    /** Maximum L2 norm of user and POI factor rows ("rec.regularization.C"). */
    double C;
    /** Geo user factor rows are clipped to norm {@code alpha * C} ("rec.regularization.alpha"). */
    double alpha;
    /** Number of geographical neighbors kept per POI ("rec.item.knn"). */
    int knn;
    /** Geographical influence per POI; rebuilt from {@code poiFactors} at the start of each iteration. */
    DenseMatrix geoInfluenceMatrix;
    /** Harmonic numbers: {@code E[n] = 1 + 1/2 + ... + 1/n}, with {@code E[0] = 0}. */
    double[] E;
    /** For every user, the set of POI indices present in that user's training row. */
    protected List<Set<Integer>> userPoisSet;
    /** Per-POI coordinates, passed to {@link #getDistance} as (latitude, longitude) key/value pairs. */
    KeyValue<Double, Double>[] locationCoordinates;

    /**
     * Reads hyper-parameters, initializes the factor matrices, and precomputes the data the
     * training loop needs: per-user POI sets, the KNN weight matrix, and the harmonic numbers.
     *
     * <p>Defaults: epsilon 0.3, C 1.0, alpha 0.2, knn 300.
     *
     * @throws LibrecException if the superclass setup fails
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.numPois = this.numItems;
        this.epsilon = this.conf.getDouble("rec.ranking.epsilon", 0.3);
        this.C = this.conf.getDouble("rec.regularization.C", 1.0);
        this.alpha = this.conf.getDouble("rec.regularization.alpha", 0.2);
        this.knn = this.conf.getInt("rec.item.knn", 300);
        this.geoInfluenceMatrix = new DenseMatrix(this.numPois, this.numFactors);
        this.userFactors = new DenseMatrix(this.numUsers, this.numFactors);
        this.geoUserFactors = new DenseMatrix(this.numUsers, this.numFactors);
        this.poiFactors = new DenseMatrix(this.numPois, this.numFactors);
        double initStd = 0.1;
        this.userFactors.init(this.initMean, initStd);
        this.geoUserFactors.init(this.initMean, initStd);
        this.poiFactors.init(this.initMean, initStd);
        this.userPoisSet = this.getUserPoisSet(this.trainMatrix);
        this.locationCoordinates = ((LocationDataAppender) this.getDataModel().getDataAppender()).getLocationAppender();
        this.poiKNNWeightMatrix = this.getPoiKNNWeightMatrix(this.knn);
        // E[n] is the n-th harmonic number; it converts an estimated rank into a loss weight.
        this.E = new double[this.numPois + 1];
        for (int i = 1; i <= this.numPois; ++i) {
            this.E[i] = this.E[i - 1] + 1.0 / (double) i;
        }
    }

    /**
     * Runs {@code numIterations} epochs of ranking-based SGD over all training entries.
     *
     * @throws LibrecException if rebuilding the geographical influence matrix fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        // The training matrix never changes, so each user's POI -> row-position map is built once
        // instead of once per sampled negative.
        List<Map<Integer, Integer>> userPoiPositions = buildUserPoiPositions();
        for (int iter = 1; iter <= this.numIterations; ++iter) {
            this.geoInfluenceMatrix = this.updateGeoInfluenceMatrix();
            this.loss = 0.0;
            // Violation search scores against a snapshot taken at the start of the iteration;
            // the SGD steps below modify the live factor matrices.
            DenseMatrix tempUserFactors = new DenseMatrix(this.userFactors);
            DenseMatrix tempGeoUserFactors = new DenseMatrix(this.geoUserFactors);
            DenseMatrix tempPoiFactors = new DenseMatrix(this.poiFactors);
            for (MatrixEntry trainMatrixEntry : this.trainMatrix) {
                int userIdx = trainMatrixEntry.row();
                int posPoiIdx = trainMatrixEntry.column();
                double posRealRating = trainMatrixEntry.get();
                double posPredictRating = scoreWith(tempUserFactors, tempGeoUserFactors, tempPoiFactors, userIdx, posPoiIdx);
                Set<Integer> poisSet = this.userPoisSet.get(userIdx);
                Map<Integer, Integer> poisPosList = userPoiPositions.get(userIdx);

                // Draw negatives until one violates the ranking, giving up after numPois draws.
                // sampleCount counts every draw, so it is in [1, numPois]. That keeps
                // numPois / sampleCount free of division by zero and at least 1.
                int sampleCount = 0;
                int negPoiIdx;
                double negPredictRating;
                boolean violated;
                do {
                    negPoiIdx = Randoms.uniform(0, this.numPois);
                    ++sampleCount;
                    negPredictRating = scoreWith(tempUserFactors, tempGeoUserFactors, tempPoiFactors, userIdx, negPoiIdx);
                    double negRealRating = poisSet.contains(negPoiIdx)
                            ? this.trainMatrix.row(userIdx).getAtPosition(poisPosList.get(negPoiIdx))
                            : 0.0;
                    violated = this.indicator(posRealRating, negRealRating)
                            * this.indicator(negPredictRating + this.epsilon, posPredictRating) == 1;
                } while (!violated && sampleCount < this.numPois);
                if (!violated) {
                    continue;
                }

                // numPois / sampleCount estimates the positive's rank; E[] turns that into a weight.
                int lowerBound = this.numPois / sampleCount;
                double s = Maths.logistic(negPredictRating + this.epsilon - posPredictRating);
                this.loss += this.E[lowerBound] * s;
                double uij = s * (1.0 - s);
                double ita = this.E[lowerBound] * uij;
                double step = (double) this.learnRate * ita;

                DenseVector updateUserVec = this.poiFactors.row(negPoiIdx).minus(this.poiFactors.row(posPoiIdx)).times(step);
                this.userFactors.set(userIdx, this.userFactors.row(userIdx).minus(updateUserVec));
                DenseVector updateGeoUserVec = this.geoInfluenceMatrix.row(negPoiIdx).minus(this.geoInfluenceMatrix.row(posPoiIdx)).times(step);
                this.geoUserFactors.set(userIdx, this.geoUserFactors.row(userIdx).minus(updateGeoUserVec));
                // NOTE(review): this step uses the user vector already updated above (unchanged behavior).
                DenseVector updatePoiVec = this.userFactors.row(userIdx).times(step);
                this.poiFactors.set(posPoiIdx, this.poiFactors.row(posPoiIdx).plus(updatePoiVec));
                this.poiFactors.set(negPoiIdx, this.poiFactors.row(negPoiIdx).minus(updatePoiVec));

                // Project the touched rows back inside their norm bounds.
                clipRowNorm(this.userFactors, userIdx, this.C);
                clipRowNorm(this.geoUserFactors, userIdx, this.alpha * this.C);
                clipRowNorm(this.poiFactors, posPoiIdx, this.C);
                clipRowNorm(this.poiFactors, negPoiIdx, this.C);
            }
            if (this.isConverged(iter) && this.earlyStop) {
                break;
            }
            this.updateLRate(iter);
        }
    }

    /** Maps, for every user, each visited POI index to its position within that user's sparse training row. */
    private List<Map<Integer, Integer>> buildUserPoiPositions() {
        List<Map<Integer, Integer>> positions = new ArrayList<Map<Integer, Integer>>(this.numUsers);
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            int[] poiIndices = this.trainMatrix.row(userIdx).getIndices();
            Map<Integer, Integer> userPositions = new HashMap<Integer, Integer>();
            for (int position = 0; position < poiIndices.length; ++position) {
                userPositions.put(poiIndices[position], position);
            }
            positions.add(userPositions);
        }
        return positions;
    }

    /** Scores POI {@code poiIdx} for {@code userIdx} with the given factors and the current geo influence matrix. */
    private double scoreWith(DenseMatrix userFactorMatrix, DenseMatrix geoUserFactorMatrix,
                             DenseMatrix poiFactorMatrix, int userIdx, int poiIdx) {
        return userFactorMatrix.row(userIdx).dot(poiFactorMatrix.row(poiIdx))
                + geoUserFactorMatrix.row(userIdx).dot(this.geoInfluenceMatrix.row(poiIdx));
    }

    /** Rescales row {@code rowIdx} of {@code matrix} so that its L2 norm does not exceed {@code maxNorm}. */
    private static void clipRowNorm(DenseMatrix matrix, int rowIdx, double maxNorm) {
        double norm = matrix.row(rowIdx).norm(2.0);
        if (norm > maxNorm) {
            matrix.set(rowIdx, matrix.row(rowIdx).times(maxNorm / norm));
        }
    }

    /**
     * Builds the row-normalized geographical neighbor weight matrix.
     *
     * <p>For each POI, distances to all other POIs are computed with {@link #getDistance}, and the
     * top {@code kNearest} entries are kept via {@code Lists.sortListTopK(..., false, kNearest)}
     * (NOTE(review): assumed to keep the smallest distances — verify against {@code Lists}). Each kept
     * neighbor gets weight {@code 2.0} if its distance is below 0.5, otherwise {@code 1 / distance}.
     * Each non-empty row is then divided by its sum.
     *
     * @param kNearest number of neighbors to keep per POI (must be non-null)
     * @return a numPois x numPois sparse matrix of normalized neighbor weights
     */
    public SequentialAccessSparseMatrix getPoiKNNWeightMatrix(Integer kNearest) {
        HashBasedTable<Integer, Integer, Double> dataTable = HashBasedTable.create();
        for (int poiIdx = 0; poiIdx < this.numPois; ++poiIdx) {
            List<Map.Entry<Integer, Double>> locationNeighbors = new ArrayList<Map.Entry<Integer, Double>>(this.numPois);
            KeyValue<Double, Double> location = this.locationCoordinates[poiIdx];
            for (int neighborItemIdx = 0; neighborItemIdx < this.numPois; ++neighborItemIdx) {
                if (poiIdx == neighborItemIdx) {
                    continue;
                }
                KeyValue<Double, Double> neighborLocation = this.locationCoordinates[neighborItemIdx];
                double distance = this.getDistance(location.getKey(), location.getValue(), neighborLocation.getKey(), neighborLocation.getValue());
                locationNeighbors.add(new AbstractMap.SimpleImmutableEntry<Integer, Double>(neighborItemIdx, distance));
            }
            locationNeighbors = Lists.sortListTopK(locationNeighbors, false, kNearest);
            for (Map.Entry<Integer, Double> neighbor : locationNeighbors) {
                double distance = neighbor.getValue();
                // Cap the weight of very close neighbors instead of letting 1 / distance blow up.
                double weight = distance < 0.5 ? 2.0 : 1.0 / distance;
                dataTable.put(poiIdx, neighbor.getKey(), weight);
            }
        }
        SequentialAccessSparseMatrix poiKNNWeightMatrix = new SequentialAccessSparseMatrix(this.numPois, this.numPois, dataTable);
        HashBasedTable<Integer, Integer, Double> normalizedDataTable = HashBasedTable.create();
        for (int itemIdx = 0; itemIdx < this.numPois; ++itemIdx) {
            // Rows with no neighbors have no entries, so the (zero) sum is never used as a divisor.
            double rowSum = poiKNNWeightMatrix.row(itemIdx).sum();
            for (Vector.VectorEntry vectorEntry : poiKNNWeightMatrix.row(itemIdx)) {
                normalizedDataTable.put(itemIdx, vectorEntry.index(), vectorEntry.get() / rowSum);
            }
        }
        return new SequentialAccessSparseMatrix(this.numPois, this.numPois, normalizedDataTable);
    }

    /**
     * Recomputes each POI's geographical influence vector as the weighted sum of its neighbors'
     * current POI factors, using {@code poiKNNWeightMatrix}.
     *
     * @return a fresh numPois x numFactors matrix; this method does not assign the field itself
     * @throws LibrecException declared for subclasses; not thrown here
     */
    public DenseMatrix updateGeoInfluenceMatrix() throws LibrecException {
        DenseMatrix influence = new DenseMatrix(this.numPois, this.numFactors);
        for (int poiIdx = 0; poiIdx < this.numPois; ++poiIdx) {
            for (Vector.VectorEntry vectorEntry : this.poiKNNWeightMatrix.row(poiIdx)) {
                influence.set(poiIdx, influence.row(poiIdx).plus(this.poiFactors.row(vectorEntry.index()).times(vectorEntry.get())));
            }
        }
        return influence;
    }

    /**
     * Great-circle (haversine) distance between two points given in degrees.
     *
     * @return the distance in kilometres, using an Earth radius of 6378137 m
     */
    public double getDistance(double lat1, double long1, double lat2, double long2) {
        double R = 6378137.0;
        lat1 = lat1 * Math.PI / 180.0;
        lat2 = lat2 * Math.PI / 180.0;
        double a = lat1 - lat2;
        double b = (long1 - long2) * Math.PI / 180.0;
        double sina = Math.sin(a / 2.0);
        double sinb = Math.sin(b / 2.0);
        double distance = 2.0 * R * Math.asin(Math.sqrt(sina * sina + Math.cos(lat1) * Math.cos(lat2) * sinb * sinb));
        return distance / 1000.0;
    }

    /** Returns 1 if {@code i > j}, otherwise 0. */
    public int indicator(double i, double j) {
        return i > j ? 1 : 0;
    }

    /** Predicted score: user/POI factor dot product plus the geo user factor times the POI's geo influence. */
    @Override
    protected double predict(int userIdx, int poiIdx) {
        return scoreWith(this.userFactors, this.geoUserFactors, this.poiFactors, userIdx, poiIdx);
    }

    /** Collects, for each user row of {@code sparseMatrix}, the set of column (POI) indices it stores. */
    private List<Set<Integer>> getUserPoisSet(SequentialAccessSparseMatrix sparseMatrix) {
        ArrayList<Set<Integer>> userPoisSet = new ArrayList<Set<Integer>>();
        for (int userIdx = 0; userIdx < this.numUsers; ++userIdx) {
            int[] itemIndexes = sparseMatrix.row(userIdx).getIndices();
            Integer[] inputBoxed = ArrayUtils.toObject(itemIndexes);
            List<Integer> itemList = Arrays.asList(inputBoxed);
            userPoisSet.add(new HashSet<Integer>(itemList));
        }
        return userPoisSet;
    }
}

