/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.ext;

import com.clearspring.analytics.util.Lists;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Maps;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SequentialAccessSparseMatrix;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.Vector;
import net.librec.recommender.MatrixRecommender;

/**
 * Item-to-item association-rule recommender.
 *
 * <p>Each user's row of the training matrix is treated as one transaction. For every ordered
 * item pair (X, Y) that co-occurs in at least one transaction, the rule confidence
 * {@code conf(X => Y) = |users with entries for X and Y| / |users with an entry for X|}
 * is stored at {@code associations[X][Y]} (this includes the trivial rule X => X with
 * confidence 1). A user's score for item Y is the sum, over every antecedent X with a stored
 * rule into Y, of the user's training value for X multiplied by {@code conf(X => Y)}.
 */
public class AssociationRuleRecommender
extends MatrixRecommender {
    /** Rule confidences: row index = antecedent item X, column index = consequent item Y. */
    private SequentialAccessSparseMatrix associations;
    /** Dense copy of the training ratings, read by {@link #predict(int, int)}. */
    private DenseMatrix userItemRates;

    /**
     * Runs the parent setup, then copies every training entry into a dense
     * {@code rowSize x columnSize} matrix so a user's value for any item can be read directly.
     *
     * @throws LibrecException propagated from {@code super.setup()}
     */
    @Override
    protected void setup() throws LibrecException {
        super.setup();
        this.userItemRates = new DenseMatrix(this.trainMatrix.rowSize(), this.trainMatrix.columnSize());
        for (MatrixEntry entry : this.trainMatrix) {
            this.userItemRates.set(entry.row(), entry.column(), entry.get());
        }
    }

    /**
     * Mines pairwise association rules from the training matrix into {@link #associations}.
     *
     * <p>For each antecedent item X, co-occurrence counts are gathered by walking only the
     * transactions (user rows) that contain X. This replaces testing all {@code numItems}
     * candidate consequents against each such user. The cost is therefore proportional to the
     * number of co-occurring (user, X, Y) triples rather than {@code numItems * nnz}. Only pairs
     * with a positive count are stored, so the resulting matrix stays sparse.
     *
     * @throws LibrecException declared by the parent contract
     */
    @Override
    protected void trainModel() throws LibrecException {
        HashBasedTable<Integer, Integer, Double> associationTable =
                HashBasedTable.create(this.numItems, this.numItems);
        // coCounts[y] = number of users who have entries for both the current antecedent and y.
        // Each slot is reset to 0 as soon as it is emitted, so the array is reused across items.
        // NOTE(review): assumes item indices in trainMatrix rows are < numItems, i.e. numItems
        // matches trainMatrix.columnSize() as setup() implicitly assumes — confirm.
        int[] coCounts = new int[this.numItems];
        for (int itemIdx = 0; itemIdx < this.numItems; ++itemIdx) {
            SequentialSparseVector userRatingsVector = this.trainMatrix.column(itemIdx);
            int userCount = userRatingsVector.size();
            if (userCount == 0) {
                // No transaction contains this item, so it is the antecedent of no rule.
                continue;
            }
            // Pass 1: count co-occurrences over the transactions that contain itemIdx.
            for (Vector.VectorEntry userEntry : userRatingsVector) {
                for (Vector.VectorEntry itemEntry : this.trainMatrix.row(userEntry.index())) {
                    coCounts[itemEntry.index()]++;
                }
            }
            // Pass 2: emit each co-occurring item exactly once (first visit wins) and clear its
            // counter, leaving coCounts all-zero for the next antecedent.
            for (Vector.VectorEntry userEntry : userRatingsVector) {
                for (Vector.VectorEntry itemEntry : this.trainMatrix.row(userEntry.index())) {
                    int assoItemIdx = itemEntry.index();
                    int count = coCounts[assoItemIdx];
                    if (count > 0) {
                        associationTable.put(itemIdx, assoItemIdx, (double) count / userCount);
                        coCounts[assoItemIdx] = 0;
                    }
                }
            }
        }
        this.associations = new SequentialAccessSparseMatrix(this.numItems, this.numItems, associationTable);
    }

    /**
     * Scores {@code itemIdx} for {@code userIdx} as the sum over antecedent items X of
     * {@code userItemRates[userIdx][X] * conf(X => itemIdx)}.
     *
     * @param userIdx user row index
     * @param itemIdx consequent item index
     * @return the confidence-weighted sum; 0.0 when no rule points at {@code itemIdx}
     * @throws LibrecException declared by the parent contract
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        double predictRatings = 0.0;
        // Column itemIdx holds conf(X => itemIdx); each entry's index is the antecedent X.
        for (Vector.VectorEntry ruleEntry : this.associations.column(itemIdx)) {
            int antecedentIdx = ruleEntry.index();
            double confidence = ruleEntry.get();
            double value = this.userItemRates.get(userIdx, antecedentIdx);
            predictRatings += value * confidence;
        }
        return predictRatings;
    }
}

