/*
 * Decompiled with CFR 0.152.
 */
package net.librec.data.splitter;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import net.librec.common.LibrecException;
import net.librec.conf.Configuration;
import net.librec.data.DataConvertor;
import net.librec.data.convertor.ArffDataConvertor;
import net.librec.data.splitter.AbstractDataSplitter;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SequentialAccessSparseMatrix;
import net.librec.math.structure.SequentialSparseVector;
import net.librec.math.structure.Vector;
import net.librec.util.RatingContext;
import org.apache.commons.lang.StringUtils;

/**
 * Splits the preference matrix into a train part and a test part (plus a
 * validation part in {@code valid} mode) according to configured ratios.
 *
 * <p>Every strategy follows the same recipe: the output matrices start as full
 * copies of {@code preferenceMatrix}; each entry is then overwritten with
 * {@code 0.0} in every copy it must NOT belong to; finally {@code reshape()} is
 * called on each copy (presumably compacting away the zeroed entries -- confirm
 * against {@code SequentialAccessSparseMatrix}).
 *
 * <p>In every strategy {@code ratio} is the share of entries that stays in the
 * train matrix. A ratio outside the open interval (0, 1) makes the strategy a
 * no-op.
 */
public class RatioDataSplitter extends AbstractDataSplitter {
    /** Per-entry timestamps; only loaded by {@link #splitData()} for the "UIRT" column format. */
    private SequentialAccessSparseMatrix datetimeMatrix;

    public RatioDataSplitter() {
    }

    /**
     * @param dataConvertor source of the preference (and datetime) matrix
     * @param conf          configuration holding the splitter settings
     */
    public RatioDataSplitter(DataConvertor dataConvertor, Configuration conf) {
        this.dataConvertor = dataConvertor;
        this.conf = conf;
    }

    /**
     * Loads the preference matrix if it is not set yet and runs the strategy
     * named by the {@code data.splitter.ratio} configuration key.
     *
     * @throws LibrecException if {@code data.splitter.ratio} is missing or names
     *                         no known strategy (previously this was an NPE or a
     *                         silent no-op that left the output matrices unset)
     */
    @Override
    public void splitData() throws LibrecException {
        if (null == this.preferenceMatrix) {
            this.preferenceMatrix = this.dataConvertor.getPreferenceMatrix(this.conf);
            if (!(this.dataConvertor instanceof ArffDataConvertor) && StringUtils.equals(this.conf.get("data.column.format"), "UIRT")) {
                this.datetimeMatrix = this.dataConvertor.getDatetimeMatrix();
            }
        }
        String splitter = this.conf.get("data.splitter.ratio");
        if (splitter == null) {
            throw new LibrecException("Configuration key data.splitter.ratio is not set");
        }
        // Locale.ROOT keeps the match independent of the JVM default locale
        // (e.g. "ITEM" would lower-case to a dotless i under a Turkish locale).
        switch (splitter.toLowerCase(Locale.ROOT)) {
            case "rating":
                this.getRatioByRating(this.readTrainRatio());
                break;
            case "user":
                this.getRatioByUser(this.readTrainRatio());
                break;
            case "userfixed":
                this.getFixedRatioByUser(this.readTrainRatio());
                break;
            case "item":
                this.getRatioByItem(this.readTrainRatio());
                break;
            case "valid": {
                double trainRatio = this.readTrainRatio();
                double validationRatio = Double.parseDouble(this.conf.get("data.splitter.validset.ratio"));
                this.getRatio(trainRatio, validationRatio);
                break;
            }
            case "ratingdate":
                this.getRatioByRatingDate(this.readTrainRatio());
                break;
            case "userdate":
                this.getRatioByUserDate(this.readTrainRatio());
                break;
            case "itemdate":
                this.getRatioByItemDate(this.readTrainRatio());
                break;
            default:
                throw new LibrecException("Unsupported data.splitter.ratio value: " + splitter);
        }
    }

    /**
     * Assigns every entry independently at random: with probability
     * {@code ratio} it stays in train, otherwise it stays in test.
     *
     * @param ratio train share, must lie strictly between 0 and 1
     */
    public void getRatioByRating(double ratio) {
        if (!isValidRatio(ratio)) {
            return;
        }
        this.copyPreferencesIntoTrainAndTest();
        for (MatrixEntry matrixEntry : this.preferenceMatrix) {
            if (Randoms.uniform() < ratio) {
                // train entry: remove it from the test copy
                this.testMatrix.setAtColumnPosition(matrixEntry.row(), matrixEntry.columnPosition(), 0.0);
            } else {
                this.trainMatrix.setAtColumnPosition(matrixEntry.row(), matrixEntry.columnPosition(), 0.0);
            }
        }
        this.reshapeTrainAndTest();
    }

    /**
     * Sorts all entries by their timestamp context and keeps the first
     * {@code ratio} share (in {@code RatingContext} sort order, presumably
     * oldest first -- confirm) in train and the remainder in test.
     *
     * @param ratio train share, must lie strictly between 0 and 1
     * @throws IllegalStateException if no datetime matrix has been loaded
     */
    public void getRatioByRatingDate(double ratio) {
        if (!isValidRatio(ratio)) {
            return;
        }
        SequentialAccessSparseMatrix datetimes = this.requireDatetimeMatrix();
        this.copyPreferencesIntoTrainAndTest();
        List<RatingContext> ratingContexts = new ArrayList<>(datetimes.size());
        for (MatrixEntry matrixEntry : this.preferenceMatrix) {
            // the "item" slot carries the column POSITION within the row, not the column index
            ratingContexts.add(new RatingContext(matrixEntry.row(), matrixEntry.columnPosition(), (long) datetimes.getAtColumnPosition(matrixEntry.row(), matrixEntry.columnPosition())));
        }
        Collections.sort(ratingContexts);
        int trainSize = (int) ((double) ratingContexts.size() * ratio);
        for (int index = 0; index < ratingContexts.size(); ++index) {
            RatingContext rc = ratingContexts.get(index);
            if (index < trainSize) {
                this.testMatrix.setAtColumnPosition(rc.getUser(), rc.getItem(), 0.0);
            } else {
                this.trainMatrix.setAtColumnPosition(rc.getUser(), rc.getItem(), 0.0);
            }
        }
        this.reshapeTrainAndTest();
    }

    /**
     * Walks the matrix row by row and assigns every entry independently at
     * random: with probability {@code ratio} it stays in train, otherwise in
     * test.
     *
     * <p>Unlike before, an out-of-range ratio no longer reaches the
     * {@code reshape()} calls, which threw a NullPointerException when the
     * train/test matrices had never been created.
     *
     * @param ratio train share, must lie strictly between 0 and 1
     */
    public void getRatioByUser(double ratio) {
        if (!isValidRatio(ratio)) {
            return;
        }
        this.copyPreferencesIntoTrainAndTest();
        int rowSize = this.preferenceMatrix.rowSize();
        for (int rowIndex = 0; rowIndex < rowSize; ++rowIndex) {
            for (Vector.VectorEntry vectorEntry : this.preferenceMatrix.row(rowIndex)) {
                if (Randoms.uniform() < ratio) {
                    this.testMatrix.setAtColumnPosition(rowIndex, vectorEntry.position(), 0.0);
                } else {
                    this.trainMatrix.setAtColumnPosition(rowIndex, vectorEntry.position(), 0.0);
                }
            }
        }
        this.reshapeTrainAndTest();
    }

    /**
     * For every row, moves exactly {@code floor(n * (1 - ratio))} randomly
     * chosen entries (n = number of entries in the row) to test and keeps the
     * rest in train. Rows too small to yield a single test entry stay entirely
     * in train.
     *
     * @param ratio train share, must lie strictly between 0 and 1
     */
    public void getFixedRatioByUser(double ratio) {
        if (!isValidRatio(ratio)) {
            return;
        }
        this.copyPreferencesIntoTrainAndTest();
        int rowSize = this.preferenceMatrix.rowSize();
        for (int rowIndex = 0; rowIndex < rowSize; ++rowIndex) {
            int numRated = this.preferenceMatrix.row(rowIndex).getNumEntries();
            int numTest = (int) Math.floor((double) numRated * (1.0 - ratio));
            if (numTest < 1) {
                // No test entry can be drawn: keep the whole row in train only,
                // otherwise the same ratings would appear in both train and test.
                for (int columnPosition = 0; columnPosition < numRated; ++columnPosition) {
                    this.testMatrix.setAtColumnPosition(rowIndex, columnPosition, 0.0);
                }
                continue;
            }
            try {
                // The merge-style walk below assumes the drawn positions are
                // distinct and ascending -- TODO confirm against Randoms.nextIntArray.
                int[] testPositions = Randoms.nextIntArray(numTest, numRated);
                int nextTest = 0;
                for (int columnPosition = 0; columnPosition < numRated; ++columnPosition) {
                    if (nextTest < testPositions.length && testPositions[nextTest] == columnPosition) {
                        // drawn as a test entry: remove it from the train copy
                        this.trainMatrix.setAtColumnPosition(rowIndex, columnPosition, 0.0);
                        ++nextTest;
                    } else {
                        this.testMatrix.setAtColumnPosition(rowIndex, columnPosition, 0.0);
                    }
                }
            } catch (Exception e) {
                this.LOG.error("This error should not happen because k cannot be outside of the range if ratio is " + ratio, e);
            }
        }
        this.reshapeTrainAndTest();
    }

    /**
     * For every row, sorts the row's entries by timestamp context and keeps the
     * first {@code ratio} share (in {@code RatingContext} sort order, presumably
     * oldest first -- confirm) in train and the remainder in test.
     *
     * @param ratio train share, must lie strictly between 0 and 1
     * @throws IllegalStateException if no datetime matrix has been loaded
     */
    public void getRatioByUserDate(double ratio) {
        if (!isValidRatio(ratio)) {
            return;
        }
        SequentialAccessSparseMatrix datetimes = this.requireDatetimeMatrix();
        this.copyPreferencesIntoTrainAndTest();
        int rowSize = this.preferenceMatrix.rowSize();
        for (int rowIndex = 0; rowIndex < rowSize; ++rowIndex) {
            SequentialSparseVector itemRatingVector = this.preferenceMatrix.row(rowIndex);
            if (itemRatingVector.getNumEntries() < 1) {
                continue;
            }
            List<RatingContext> itemRatingList = new ArrayList<>(itemRatingVector.getNumEntries());
            for (Vector.VectorEntry vectorEntry : itemRatingVector) {
                // Order by the entry's timestamp, not by its rating value.
                // Assumes the datetime matrix shares the preference matrix layout
                // so that column positions line up -- the other date strategies
                // rely on the same assumption.
                itemRatingList.add(new RatingContext(rowIndex, vectorEntry.position(), (long) datetimes.getAtColumnPosition(rowIndex, vectorEntry.position())));
            }
            int trainSize = (int) ((double) itemRatingList.size() * ratio);
            Collections.sort(itemRatingList);
            for (int index = 0; index < itemRatingList.size(); ++index) {
                int columnPosition = itemRatingList.get(index).getItem();
                if (index < trainSize) {
                    this.testMatrix.setAtColumnPosition(rowIndex, columnPosition, 0.0);
                } else {
                    this.trainMatrix.setAtColumnPosition(rowIndex, columnPosition, 0.0);
                }
            }
        }
        this.reshapeTrainAndTest();
    }

    /**
     * Walks the matrix column by column and assigns every entry independently
     * at random: with probability {@code ratio} it stays in train, otherwise in
     * test.
     *
     * @param ratio train share, must lie strictly between 0 and 1
     */
    public void getRatioByItem(double ratio) {
        if (!isValidRatio(ratio)) {
            return;
        }
        this.copyPreferencesIntoTrainAndTest();
        int columnSize = this.preferenceMatrix.columnSize();
        for (int columnIndex = 0; columnIndex < columnSize; ++columnIndex) {
            for (Vector.VectorEntry vectorEntry : this.preferenceMatrix.column(columnIndex)) {
                if (Randoms.uniform() < ratio) {
                    this.testMatrix.setAtRowPosition(vectorEntry.position(), columnIndex, 0.0);
                } else {
                    this.trainMatrix.setAtRowPosition(vectorEntry.position(), columnIndex, 0.0);
                }
            }
        }
        this.reshapeTrainAndTest();
    }

    /**
     * For every column, sorts the column's entries by timestamp context and
     * keeps the first {@code ratio} share (in {@code RatingContext} sort order,
     * presumably oldest first -- confirm) in train and the remainder in test.
     *
     * @param ratio train share, must lie strictly between 0 and 1
     * @throws IllegalStateException if no datetime matrix has been loaded
     */
    public void getRatioByItemDate(double ratio) {
        if (!isValidRatio(ratio)) {
            return;
        }
        SequentialAccessSparseMatrix datetimes = this.requireDatetimeMatrix();
        this.copyPreferencesIntoTrainAndTest();
        int columnSize = this.preferenceMatrix.columnSize();
        for (int columnIndex = 0; columnIndex < columnSize; ++columnIndex) {
            SequentialSparseVector userRatingVector = this.preferenceMatrix.column(columnIndex);
            if (userRatingVector.getNumEntries() < 1) {
                continue;
            }
            List<RatingContext> ratingContexts = new ArrayList<>(userRatingVector.getNumEntries());
            for (Vector.VectorEntry vectorEntry : userRatingVector) {
                // the "user" slot carries the row POSITION within the column, not the row index
                ratingContexts.add(new RatingContext(vectorEntry.position(), columnIndex, (long) datetimes.getAtRowPosition(vectorEntry.position(), columnIndex)));
            }
            Collections.sort(ratingContexts);
            int trainSize = (int) ((double) ratingContexts.size() * ratio);
            for (int index = 0; index < ratingContexts.size(); ++index) {
                int rowPosition = ratingContexts.get(index).getUser();
                if (index < trainSize) {
                    this.testMatrix.setAtRowPosition(rowPosition, columnIndex, 0.0);
                } else {
                    this.trainMatrix.setAtRowPosition(rowPosition, columnIndex, 0.0);
                }
            }
        }
        this.reshapeTrainAndTest();
    }

    /**
     * Three-way random split: each entry independently goes to train with
     * probability {@code trainRatio}, to validation with probability
     * {@code validationRatio}, and to test otherwise. Does nothing unless both
     * ratios are positive and their sum is below 1.
     *
     * @param trainRatio      train share
     * @param validationRatio validation share
     */
    public void getRatio(double trainRatio, double validationRatio) {
        if (!(trainRatio > 0.0 && validationRatio > 0.0 && trainRatio + validationRatio < 1.0)) {
            return;
        }
        this.trainMatrix = new SequentialAccessSparseMatrix(this.preferenceMatrix);
        this.validationMatrix = new SequentialAccessSparseMatrix(this.preferenceMatrix);
        this.testMatrix = new SequentialAccessSparseMatrix(this.preferenceMatrix);
        for (MatrixEntry matrixEntry : this.preferenceMatrix) {
            int row = matrixEntry.row();
            int columnPosition = matrixEntry.columnPosition();
            double rdm = Randoms.uniform();
            if (rdm < trainRatio) {
                // train entry
                this.validationMatrix.setAtColumnPosition(row, columnPosition, 0.0);
                this.testMatrix.setAtColumnPosition(row, columnPosition, 0.0);
            } else if (rdm < trainRatio + validationRatio) {
                // validation entry
                this.trainMatrix.setAtColumnPosition(row, columnPosition, 0.0);
                this.testMatrix.setAtColumnPosition(row, columnPosition, 0.0);
            } else {
                // test entry
                this.trainMatrix.setAtColumnPosition(row, columnPosition, 0.0);
                this.validationMatrix.setAtColumnPosition(row, columnPosition, 0.0);
            }
        }
        this.trainMatrix.reshape();
        this.validationMatrix.reshape();
        this.testMatrix.reshape();
    }

    /** Returns true when the ratio lies strictly between 0 and 1 (false for NaN). */
    private static boolean isValidRatio(double ratio) {
        return ratio > 0.0 && ratio < 1.0;
    }

    /** Parses the train share from the {@code data.splitter.trainset.ratio} configuration key. */
    private double readTrainRatio() {
        return Double.parseDouble(this.conf.get("data.splitter.trainset.ratio"));
    }

    /** Resets the train and test matrices to fresh full copies of the preference matrix. */
    private void copyPreferencesIntoTrainAndTest() {
        this.trainMatrix = new SequentialAccessSparseMatrix(this.preferenceMatrix);
        this.testMatrix = new SequentialAccessSparseMatrix(this.preferenceMatrix);
    }

    /** Calls {@code reshape()} on the train and test matrices once all entries are assigned. */
    private void reshapeTrainAndTest() {
        this.trainMatrix.reshape();
        this.testMatrix.reshape();
    }

    /** Returns the datetime matrix, failing with a descriptive error instead of an NPE when it was never loaded. */
    private SequentialAccessSparseMatrix requireDatetimeMatrix() {
        if (this.datetimeMatrix == null) {
            throw new IllegalStateException("Date-based ratio splitting needs a datetime matrix, but none was loaded (splitData() only loads it for data.column.format=UIRT)");
        }
        return this.datetimeMatrix;
    }
}

