/*
 * Decompiled with CFR 0.152.
 */
package net.librec.data.splitter;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import net.librec.common.LibrecException;
import net.librec.conf.Configuration;
import net.librec.data.DataConvertor;
import net.librec.data.convertor.ArffDataConvertor;
import net.librec.data.splitter.AbstractDataSplitter;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.SequentialAccessSparseMatrix;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.Vector;
import net.librec.util.RatingContext;
import org.apache.commons.lang.StringUtils;

/**
 * Data splitter implementing the "given-N" protocol.
 *
 * <p>For every row (user) or column (item) of the preference matrix, {@code N} ratings are
 * kept in the training matrix and the remaining ratings are kept in the test matrix. A row or
 * column with at most {@code N} ratings is kept entirely in the training matrix. Both matrices
 * start as copies of the preference matrix; entries are removed from one side by setting them
 * to {@code 0.0} and then calling {@code reshape()} (presumably to drop the explicit zeros).
 *
 * <p>Configuration keys read by {@link #splitData()}:
 * <ul>
 *   <li>{@code data.splitter.givenn}: split mode, one of {@code user}, {@code item},
 *       {@code userdate}, {@code itemdate} (case-insensitive).</li>
 *   <li>{@code data.splitter.givenn.n}: the positive integer {@code N}.</li>
 * </ul>
 *
 * <p>The {@code user}/{@code item} modes choose the kept ratings with
 * {@code Randoms.nextIntArray}. The date modes sort each row/column by its datetime values
 * using {@link RatingContext}'s natural ordering (presumably oldest first; verify against
 * {@code RatingContext.compareTo}) and keep the first {@code N}. The date modes need a datetime
 * matrix, which is only loaded for non-ARFF input whose {@code data.column.format} is
 * {@code UIRT}.
 */
public class GivenNDataSplitter extends AbstractDataSplitter {
    /** Full rating matrix; loaded lazily by {@link #splitData()}. */
    private SequentialAccessSparseMatrix preferenceMatrix;

    /**
     * Datetime values of the ratings, read by position with the same row/column coordinates as
     * {@link #preferenceMatrix} (so it is assumed to share its sparsity structure).
     * {@code null} unless the input is non-ARFF UIRT data.
     */
    private SequentialAccessSparseMatrix datetimeMatrix;

    /** Creates a splitter whose convertor and configuration are assigned later. */
    public GivenNDataSplitter() {
    }

    /**
     * Creates a splitter reading data from {@code dataConvertor} and settings from {@code conf}.
     *
     * @param dataConvertor source of the preference (and optionally datetime) matrix
     * @param conf          configuration holding the {@code data.splitter.givenn*} keys
     */
    public GivenNDataSplitter(DataConvertor dataConvertor, Configuration conf) {
        this.dataConvertor = dataConvertor;
        this.conf = conf;
    }

    /**
     * Loads the data if needed and performs the given-N split selected by
     * {@code data.splitter.givenn}, filling {@code trainMatrix} and {@code testMatrix}.
     *
     * @throws LibrecException if the mode is missing or unknown, if {@code N} is missing,
     *                         non-numeric or not positive, or if the split itself fails
     *                         (e.g. a date mode without datetime data)
     */
    @Override
    public void splitData() throws LibrecException {
        if (null == this.preferenceMatrix) {
            this.preferenceMatrix = this.dataConvertor.getPreferenceMatrix(this.conf);
            if (!(this.dataConvertor instanceof ArffDataConvertor)
                    && StringUtils.equals(this.conf.get("data.column.format"), "UIRT")) {
                this.datetimeMatrix = this.dataConvertor.getDatetimeMatrix();
            }
        }
        String splitter = this.conf.get("data.splitter.givenn");
        if (splitter == null) {
            throw new LibrecException("Missing required configuration key: data.splitter.givenn");
        }
        // Locale.ROOT avoids locale-dependent case mapping (e.g. Turkish dotless i in "item").
        String mode = splitter.toLowerCase(Locale.ROOT);
        try {
            switch (mode) {
                case "user":
                    this.getGivenNByUser(this.readNumGiven());
                    break;
                case "item":
                    this.getGivenNByItem(this.readNumGiven());
                    break;
                case "userdate":
                    this.getGivenNByUserDate(this.readNumGiven());
                    break;
                case "itemdate":
                    this.getGivenNByItemDate(this.readNumGiven());
                    break;
                default:
                    throw new LibrecException("Unsupported data.splitter.givenn value: '" + splitter
                            + "' (expected user, item, userdate or itemdate)");
            }
        } catch (LibrecException e) {
            throw e;
        } catch (Exception e) {
            // Surface the failure instead of leaving train/test matrices unset.
            throw new LibrecException("Given-N split failed for mode '" + splitter + "'", e);
        }
    }

    /**
     * Reads and validates {@code data.splitter.givenn.n}.
     *
     * @return the configured {@code N}, always positive
     * @throws LibrecException if the key is missing, not an integer, or not positive
     */
    private int readNumGiven() throws LibrecException {
        String value = this.conf.get("data.splitter.givenn.n");
        int numGiven;
        try {
            // Integer.parseInt(null) also throws NumberFormatException.
            numGiven = Integer.parseInt(value);
        } catch (NumberFormatException e) {
            throw new LibrecException("data.splitter.givenn.n must be an integer, got: " + value, e);
        }
        if (numGiven <= 0) {
            throw new LibrecException("data.splitter.givenn.n must be positive, got: " + numGiven);
        }
        return numGiven;
    }

    /**
     * Returns the datetime matrix, failing fast when it was not loaded.
     *
     * @return the non-null datetime matrix
     * @throws IllegalStateException if no datetime data is available
     */
    private SequentialAccessSparseMatrix requireDatetimeMatrix() {
        if (this.datetimeMatrix == null) {
            throw new IllegalStateException("Date-based given-N split requires datetime data "
                    + "(non-ARFF input with data.column.format=UIRT)");
        }
        return this.datetimeMatrix;
    }

    /**
     * Splits row by row (per user). In each row, {@code numGiven} randomly chosen ratings stay
     * in the training matrix and the rest go to the test matrix. Rows with at most
     * {@code numGiven} ratings go entirely to the training matrix. Does nothing when
     * {@code numGiven <= 0}.
     *
     * @param numGiven number of ratings per row kept for training
     * @throws Exception propagated from random position sampling
     */
    public void getGivenNByUser(int numGiven) throws Exception {
        if (numGiven <= 0) {
            return;
        }
        this.trainMatrix = new SequentialAccessSparseMatrix(this.preferenceMatrix);
        this.testMatrix = new SequentialAccessSparseMatrix(this.preferenceMatrix);
        int rowSize = this.preferenceMatrix.rowSize();
        for (int rowIndex = 0; rowIndex < rowSize; ++rowIndex) {
            int numRated = this.preferenceMatrix.row(rowIndex).getNumEntries();
            if (numRated <= numGiven) {
                // Too few ratings to hold any out: keep the whole row for training.
                for (Vector.VectorEntry vectorEntry : this.preferenceMatrix.row(rowIndex)) {
                    this.testMatrix.setAtColumnPosition(rowIndex, vectorEntry.position(), 0.0);
                }
                continue;
            }
            // The walk below assumes nextIntArray returns distinct positions in ascending
            // order - TODO confirm against Randoms.nextIntArray.
            int[] givenPositions = Randoms.nextIntArray(numGiven, numRated);
            int next = 0;
            for (int columnPosition = 0; columnPosition < numRated; ++columnPosition) {
                if (next < givenPositions.length && givenPositions[next] == columnPosition) {
                    this.testMatrix.setAtColumnPosition(rowIndex, columnPosition, 0.0);
                    ++next;
                } else {
                    this.trainMatrix.setAtColumnPosition(rowIndex, columnPosition, 0.0);
                }
            }
        }
        this.trainMatrix.reshape();
        this.testMatrix.reshape();
    }

    /**
     * Splits row by row (per user) by datetime. Each row's ratings are sorted by their datetime
     * values; the first {@code numGiven} stay in the training matrix and the rest go to the test
     * matrix. Does nothing when {@code numGiven <= 0}.
     *
     * @param numGiven number of ratings per row kept for training
     * @throws IllegalStateException if no datetime matrix was loaded
     */
    public void getGivenNByUserDate(int numGiven) {
        if (numGiven <= 0) {
            return;
        }
        SequentialAccessSparseMatrix datetimes = this.requireDatetimeMatrix();
        this.trainMatrix = new SequentialAccessSparseMatrix(this.preferenceMatrix);
        this.testMatrix = new SequentialAccessSparseMatrix(this.preferenceMatrix);
        int rowSize = this.preferenceMatrix.rowSize();
        for (int rowIndex = 0; rowIndex < rowSize; ++rowIndex) {
            SequentialSparseVector itemRatingVector = this.preferenceMatrix.row(rowIndex);
            int numRated = itemRatingVector.getNumEntries();
            if (numRated < 1) {
                continue;
            }
            List<RatingContext> itemRatingList = new ArrayList<>(numRated);
            for (Vector.VectorEntry vectorEntry : itemRatingVector) {
                int columnPosition = vectorEntry.position();
                // Sort key must be the rating's datetime, not its rating value.
                long datetime = (long) datetimes.getAtColumnPosition(rowIndex, columnPosition);
                itemRatingList.add(new RatingContext(rowIndex, columnPosition, datetime));
            }
            Collections.sort(itemRatingList);
            for (int rank = 0; rank < itemRatingList.size(); ++rank) {
                // getItem() holds the column position stored above.
                int columnPosition = itemRatingList.get(rank).getItem();
                if (rank < numGiven) {
                    this.testMatrix.setAtColumnPosition(rowIndex, columnPosition, 0.0);
                } else {
                    this.trainMatrix.setAtColumnPosition(rowIndex, columnPosition, 0.0);
                }
            }
        }
        this.trainMatrix.reshape();
        this.testMatrix.reshape();
    }

    /**
     * Splits column by column (per item). In each column, {@code numGiven} randomly chosen
     * ratings stay in the training matrix and the rest go to the test matrix. Columns with at
     * most {@code numGiven} ratings go entirely to the training matrix. Does nothing when
     * {@code numGiven <= 0}.
     *
     * @param numGiven number of ratings per column kept for training
     * @throws Exception propagated from random position sampling
     */
    public void getGivenNByItem(int numGiven) throws Exception {
        if (numGiven <= 0) {
            return;
        }
        this.trainMatrix = new SequentialAccessSparseMatrix(this.preferenceMatrix);
        this.testMatrix = new SequentialAccessSparseMatrix(this.preferenceMatrix);
        int columnSize = this.preferenceMatrix.columnSize();
        for (int columnIndex = 0; columnIndex < columnSize; ++columnIndex) {
            int numRated = this.preferenceMatrix.column(columnIndex).getNumEntries();
            if (numRated <= numGiven) {
                // Too few ratings to hold any out: keep the whole column for training.
                for (int rowPosition = 0; rowPosition < numRated; ++rowPosition) {
                    this.testMatrix.setAtRowPosition(rowPosition, columnIndex, 0.0);
                }
                continue;
            }
            // Same ascending-order assumption on nextIntArray as in getGivenNByUser.
            int[] givenPositions = Randoms.nextIntArray(numGiven, numRated);
            int next = 0;
            for (int rowPosition = 0; rowPosition < numRated; ++rowPosition) {
                if (next < givenPositions.length && givenPositions[next] == rowPosition) {
                    this.testMatrix.setAtRowPosition(rowPosition, columnIndex, 0.0);
                    ++next;
                } else {
                    this.trainMatrix.setAtRowPosition(rowPosition, columnIndex, 0.0);
                }
            }
        }
        this.trainMatrix.reshape();
        this.testMatrix.reshape();
    }

    /**
     * Splits column by column (per item) by datetime. Each column's ratings are sorted by their
     * datetime values; the first {@code numGiven} stay in the training matrix and the rest go to
     * the test matrix. Does nothing when {@code numGiven <= 0}.
     *
     * @param numGiven number of ratings per column kept for training
     * @throws IllegalStateException if no datetime matrix was loaded
     */
    public void getGivenNByItemDate(int numGiven) {
        if (numGiven <= 0) {
            return;
        }
        SequentialAccessSparseMatrix datetimes = this.requireDatetimeMatrix();
        this.trainMatrix = new SequentialAccessSparseMatrix(this.preferenceMatrix);
        this.testMatrix = new SequentialAccessSparseMatrix(this.preferenceMatrix);
        int columnSize = this.preferenceMatrix.columnSize();
        for (int columnIndex = 0; columnIndex < columnSize; ++columnIndex) {
            SequentialSparseVector userRatingVector = this.preferenceMatrix.column(columnIndex);
            int numRated = userRatingVector.getNumEntries();
            if (numRated < 1) {
                continue;
            }
            List<RatingContext> ratingContexts = new ArrayList<>(numRated);
            for (Vector.VectorEntry vectorEntry : userRatingVector) {
                int rowPosition = vectorEntry.position();
                long datetime = (long) datetimes.getAtRowPosition(rowPosition, columnIndex);
                ratingContexts.add(new RatingContext(rowPosition, columnIndex, datetime));
            }
            Collections.sort(ratingContexts);
            for (int rank = 0; rank < ratingContexts.size(); ++rank) {
                // getUser() holds the row position stored above.
                int rowPosition = ratingContexts.get(rank).getUser();
                if (rank < numGiven) {
                    this.testMatrix.setAtRowPosition(rowPosition, columnIndex, 0.0);
                } else {
                    this.trainMatrix.setAtRowPosition(rowPosition, columnIndex, 0.0);
                }
            }
        }
        this.trainMatrix.reshape();
        this.testMatrix.reshape();
    }
}

