/*
 * Decompiled with CFR 0.152.
 */
package net.librec.recommender.content;

import java.beans.ConstructorProperties;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.NoSuchElementException;
import net.librec.recommender.content.ConvMFDocumentProvider;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cpu.nativecpu.NDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;

/**
 * {@link DataSetIterator} that turns documents from a {@link ConvMFDocumentProvider} into
 * convolution-ready feature tensors built from pre-trained {@link WordVectors}.
 *
 * <p>Each document is tokenized with the configured {@link TokenizerFactory} and every kept token
 * is replaced by its word vector. Feature tensors have shape
 * {@code [numDocs, 1, length, wordVectorSize]} when {@code sentencesAlongHeight} is true, or
 * {@code [numDocs, 1, wordVectorSize, length]} otherwise. Shorter documents are zero-padded up to
 * the longest document in the batch, which is capped at {@code maxSentenceLength} when that
 * value is positive. Labels are the {@code double[]} rows supplied by the provider.
 *
 * <p>Iteration state (the provider position and {@code cursor}) is mutated without
 * synchronization.
 */
public class ConvMFDocumentDataSetIterator implements DataSetIterator {
    /**
     * Placeholder token that stands in for an out-of-vocabulary word under
     * {@link UnknownWordHandling#UseUnknownVector}. It is compared by reference in
     * {@link #getVector(String)}, so only this exact instance maps to the unknown vector.
     */
    private static final String UNKNOWN_WORD_SENTINEL = "UNKNOWN_WORD_SENTINEL";
    /** Source of documents; may be null, in which case only {@link #loadSingleSentence} works. */
    private ConvMFDocumentProvider documentProvider = null;
    private WordVectors wordVectors;
    private TokenizerFactory tokenizerFactory;
    private UnknownWordHandling unknownWordHandling;
    private boolean useNormalizedWordVectors;
    private int minibatchSize;
    /** Maximum number of tokens kept per document; a value {@code <= 0} means "no limit". */
    private int maxSentenceLength;
    private boolean sentencesAlongHeight;
    private DataSetPreProcessor dataSetPreProcessor;
    private int wordVectorSize;
    /** Vector substituted for out-of-vocabulary tokens under {@code UseUnknownVector}. */
    private INDArray unknown;
    /** Number of examples returned by {@link #next(int)} since the last {@link #reset()}. */
    private int cursor = 0;
    private int labelSize;
    /** Document count reported by the provider; used only as an allocation hint. */
    private int numDoc;

    private ConvMFDocumentDataSetIterator(Builder builder) {
        this.documentProvider = builder.documentProvider;
        this.wordVectors = builder.wordVectors;
        this.tokenizerFactory = builder.tokenizerFactory;
        this.unknownWordHandling = builder.unknownWordHandling;
        this.useNormalizedWordVectors = builder.useNormalizedWordVectors;
        this.minibatchSize = builder.minibatchSize;
        this.maxSentenceLength = builder.maxSentenceLength;
        this.sentencesAlongHeight = builder.sentencesAlongHeight;
        this.dataSetPreProcessor = builder.dataSetPreProcessor;
        // The provider is optional (the Builder defaults it to null); hasNext()/next() report
        // its absence explicitly instead of this constructor failing with an NPE.
        if (this.documentProvider != null) {
            this.labelSize = this.documentProvider.getLabelSize();
            this.numDoc = this.documentProvider.getNumDoc();
        }
        String firstWord = this.wordVectors.vocab().wordAtIndex(0);
        this.wordVectorSize = this.wordVectors.getWordVector(firstWord).length;
        if (this.unknownWordHandling == UnknownWordHandling.UseUnknownVector) {
            this.unknown = this.lookupVector(this.wordVectors.getUNK());
            if (this.unknown == null) {
                // No embedding exists for the UNK token: fall back to zeros shaped like a real
                // word vector so that put() into the feature tensor still lines up.
                this.unknown = Nd4j.zeros(this.wordVectors.getWordVectorMatrix(firstWord).shape());
            }
        }
    }

    /**
     * Converts a single piece of text into a feature tensor with a minibatch dimension of 1.
     *
     * <p>The token dimension equals the number of retained tokens, capped at
     * {@code maxSentenceLength} when that is positive. No document provider is required.
     * NOTE(review): a sentence with no retained tokens yields a zero-length dimension; whether
     * the ND4J version in use accepts that shape is unverified.
     *
     * @param sentence raw text to tokenize
     * @return feature tensor of shape {@code [1, 1, len, wordVectorSize]} (or
     *     {@code [1, 1, wordVectorSize, len]} when {@code sentencesAlongHeight} is false)
     */
    public INDArray loadSingleSentence(String sentence) {
        List<String> tokens = this.tokenizeSentence(sentence);
        return this.buildFeatures(Collections.singletonList(tokens), this.capLength(tokens.size()));
    }

    /**
     * Returns the vector for one token. The shared sentinel maps to the unknown vector; any other
     * token is looked up in the word vectors.
     */
    private INDArray getVector(String word) {
        // Reference equality is deliberate: only the sentinel instance emitted by
        // tokenizeSentence() must map to the unknown vector, never a real token that happens to
        // have the same text.
        if (this.unknownWordHandling == UnknownWordHandling.UseUnknownVector
                && word == UNKNOWN_WORD_SENTINEL) {
            return this.unknown;
        }
        return this.lookupVector(word);
    }

    /** Looks up {@code word}, normalized or raw according to {@code useNormalizedWordVectors}. */
    private INDArray lookupVector(String word) {
        return this.useNormalizedWordVectors
                ? this.wordVectors.getWordVectorMatrixNormalized(word)
                : this.wordVectors.getWordVectorMatrix(word);
    }

    /**
     * Tokenizes {@code sentence} and handles out-of-vocabulary tokens.
     *
     * <p>Under {@code RemoveWord}, out-of-vocabulary tokens are dropped. Under
     * {@code UseUnknownVector}, they are replaced by {@link #UNKNOWN_WORD_SENTINEL}.
     */
    private List<String> tokenizeSentence(String sentence) {
        Tokenizer tokenizer = this.tokenizerFactory.create(sentence);
        List<String> tokens = new ArrayList<>();
        while (tokenizer.hasMoreTokens()) {
            String token = tokenizer.nextToken();
            if (this.wordVectors.hasWord(token)) {
                tokens.add(token);
            } else if (this.unknownWordHandling == UnknownWordHandling.UseUnknownVector) {
                tokens.add(UNKNOWN_WORD_SENTINEL);
            }
        }
        return tokens;
    }

    /** Applies the {@code maxSentenceLength} cap; a non-positive cap means unlimited. */
    private int capLength(int length) {
        return this.maxSentenceLength > 0 ? Math.min(length, this.maxSentenceLength) : length;
    }

    /**
     * Pulls up to {@code limit} documents from the provider and tokenizes each one.
     *
     * @return pairs of (tokens, label row), in provider order
     */
    private List<Pair<List<String>, double[]>> readDocuments(int limit, int capacityHint) {
        List<Pair<List<String>, double[]>> docs = new ArrayList<>(Math.max(0, capacityHint));
        while (docs.size() < limit && this.documentProvider.hasNext()) {
            Pair<String, double[]> doc = this.documentProvider.nextSentence();
            docs.add(new Pair<>(this.tokenizeSentence(doc.getFirst()), doc.getSecond()));
        }
        return docs;
    }

    /**
     * Builds the zero-padded feature tensor for {@code sentences}.
     *
     * @param sentences tokenized documents, one per minibatch row
     * @param length size of the token dimension; tokens beyond it are truncated
     */
    private INDArray buildFeatures(List<List<String>> sentences, int length) {
        int numSentences = sentences.size();
        int[] shape = this.sentencesAlongHeight
                ? new int[]{numSentences, 1, length, this.wordVectorSize}
                : new int[]{numSentences, 1, this.wordVectorSize, length};
        INDArray features = Nd4j.create(shape);
        for (int i = 0; i < numSentences; ++i) {
            List<String> sentence = sentences.get(i);
            int kept = Math.min(sentence.size(), length);
            for (int j = 0; j < kept; ++j) {
                INDArrayIndex[] indices = this.sentencesAlongHeight
                        ? new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.point(0),
                                NDArrayIndex.point(j), NDArrayIndex.all()}
                        : new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.point(0),
                                NDArrayIndex.all(), NDArrayIndex.point(j)};
                features.put(indices, this.getVector(sentence.get(j)));
            }
        }
        return features;
    }

    /**
     * Builds a {@code [numSentences, length]} mask. Each row holds 1.0 for positions that carry a
     * real token and 0.0 for padding.
     */
    private INDArray buildMask(List<List<String>> sentences, int length) {
        INDArray mask = Nd4j.create(sentences.size(), length);
        for (int i = 0; i < sentences.size(); ++i) {
            int valid = Math.min(sentences.get(i).size(), length);
            if (valid == length) {
                mask.getRow(i).assign(1.0);
            } else if (valid > 0) {
                mask.get(NDArrayIndex.point(i), NDArrayIndex.interval(0, valid)).assign(1.0);
            }
            // valid == 0: leave the row all zeros rather than indexing an empty interval.
        }
        return mask;
    }

    /**
     * Returns whether the provider has more documents.
     *
     * @throws UnsupportedOperationException if no document provider was configured
     */
    @Override
    public boolean hasNext() {
        if (this.documentProvider == null) {
            throw new UnsupportedOperationException("Cannot do next/hasNext without a sentence provider");
        }
        return this.documentProvider.hasNext();
    }

    /** Returns the next minibatch of the configured {@code minibatchSize}. */
    @Override
    public DataSet next() {
        return this.next(this.minibatchSize);
    }

    /**
     * Returns a minibatch of up to {@code num} documents.
     *
     * <p>A feature mask is attached only when some document is shorter than the padded length.
     * The pre-processor, if set, is applied before returning.
     *
     * @param num maximum number of documents to include; must be positive
     * @throws UnsupportedOperationException if no document provider was configured
     * @throws IllegalArgumentException if {@code num <= 0}
     * @throws NoSuchElementException if the provider has no documents left
     */
    @Override
    public DataSet next(int num) {
        if (this.documentProvider == null) {
            throw new UnsupportedOperationException("Cannot do next/hasNext without a sentence provider");
        }
        if (num <= 0) {
            throw new IllegalArgumentException("Number of examples must be positive, got " + num);
        }
        List<Pair<List<String>, double[]>> docs = this.readDocuments(num, num);
        if (docs.isEmpty()) {
            throw new NoSuchElementException("No documents remaining; call reset() to iterate again");
        }
        int batchSize = docs.size();
        List<List<String>> sentences = new ArrayList<>(batchSize);
        double[][] labelsData = new double[batchSize][];
        int maxLength = 0;
        int minLength = Integer.MAX_VALUE;
        for (int i = 0; i < batchSize; ++i) {
            List<String> tokens = docs.get(i).getFirst();
            sentences.add(tokens);
            labelsData[i] = docs.get(i).getSecond();
            maxLength = Math.max(maxLength, tokens.size());
            minLength = Math.min(minLength, tokens.size());
        }
        maxLength = this.capLength(maxLength);
        // Nd4j.create is backend-neutral, unlike constructing the CPU-specific NDArray directly.
        INDArray labels = Nd4j.create(labelsData);
        INDArray features = this.buildFeatures(sentences, maxLength);
        INDArray featuresMask = minLength < maxLength ? this.buildMask(sentences, maxLength) : null;
        DataSet ds = new DataSet(features, labels, featuresMask, null);
        if (this.dataSetPreProcessor != null) {
            this.dataSetPreProcessor.preProcess(ds);
        }
        this.cursor += ds.numExamples();
        return ds;
    }

    /**
     * Resets the iterator, then tokenizes every document in the provider into one feature
     * tensor. Labels are not returned. The provider is left exhausted afterwards.
     *
     * @return feature tensor covering all documents, padded or truncated as in {@link #next(int)}
     * @throws UnsupportedOperationException if no document provider was configured
     * @throws IllegalStateException if the provider yields no documents
     */
    public INDArray loadAllDocuments() {
        if (this.documentProvider == null) {
            throw new UnsupportedOperationException("Cannot load documents without a document provider");
        }
        this.reset();
        List<Pair<List<String>, double[]>> docs = this.readDocuments(Integer.MAX_VALUE, this.numDoc);
        if (docs.isEmpty()) {
            throw new IllegalStateException("Document provider returned no documents");
        }
        List<List<String>> sentences = new ArrayList<>(docs.size());
        int maxLength = 0;
        for (Pair<List<String>, double[]> doc : docs) {
            List<String> tokens = doc.getFirst();
            sentences.add(tokens);
            maxLength = Math.max(maxLength, tokens.size());
        }
        return this.buildFeatures(sentences, this.capLength(maxLength));
    }

    /** Returns the provider's {@code totalNumSentences()}. */
    public int totalExamples() {
        return this.documentProvider.totalNumSentences();
    }

    /** Returns the word-vector dimensionality. */
    @Override
    public int inputColumns() {
        return this.wordVectorSize;
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return true;
    }

    /** Zeroes the cursor and rewinds the provider, if one is configured. */
    @Override
    public void reset() {
        this.cursor = 0;
        if (this.documentProvider != null) {
            this.documentProvider.reset();
        }
    }

    @Override
    public int batch() {
        return this.minibatchSize;
    }

    public int cursor() {
        return this.cursor;
    }

    public int numExamples() {
        return this.totalExamples();
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.dataSetPreProcessor = preProcessor;
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return this.dataSetPreProcessor;
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException("Not supported");
    }

    /** Returns the label width reported by the provider (0 when no provider is configured). */
    @Override
    public int totalOutcomes() {
        return this.labelSize;
    }

    /** Always returns {@code null}: labels are numeric rows, not named classes. */
    @Override
    public List<String> getLabels() {
        return null;
    }

    /**
     * All-arguments constructor that assigns every field directly.
     *
     * <p>{@code numDoc} is not set here and stays 0. Under {@code UseUnknownVector}, the caller
     * must supply a non-null {@code unknown} vector.
     */
    @ConstructorProperties(value={"documentProvider", "wordVectors", "tokenizerFactory", "unknownWordHandling", "useNormalizedWordVectors", "minibatchSize", "maxSentenceLength", "sentencesAlongHeight", "dataSetPreProcessor", "wordVectorSize", "unknown", "cursor", "labelSize"})
    public ConvMFDocumentDataSetIterator(ConvMFDocumentProvider documentProvider, WordVectors wordVectors, TokenizerFactory tokenizerFactory, UnknownWordHandling unknownWordHandling, boolean useNormalizedWordVectors, int minibatchSize, int maxSentenceLength, boolean sentencesAlongHeight, DataSetPreProcessor dataSetPreProcessor, int wordVectorSize, INDArray unknown, int cursor, int labelSize) {
        this.documentProvider = documentProvider;
        this.wordVectors = wordVectors;
        this.tokenizerFactory = tokenizerFactory;
        this.unknownWordHandling = unknownWordHandling;
        this.useNormalizedWordVectors = useNormalizedWordVectors;
        this.minibatchSize = minibatchSize;
        this.maxSentenceLength = maxSentenceLength;
        this.sentencesAlongHeight = sentencesAlongHeight;
        this.dataSetPreProcessor = dataSetPreProcessor;
        this.wordVectorSize = wordVectorSize;
        this.unknown = unknown;
        this.cursor = cursor;
        this.labelSize = labelSize;
    }

    /** Strategy for tokens that have no word vector. */
    public static enum UnknownWordHandling {
        /** Drop the token. */
        RemoveWord,
        /** Keep the token's position and fill it with the UNK vector (zeros if UNK has none). */
        UseUnknownVector;

    }

    /** Fluent builder; only {@link #wordVectors(WordVectors)} is mandatory. */
    public static class Builder {
        private ConvMFDocumentProvider documentProvider = null;
        private WordVectors wordVectors;
        private TokenizerFactory tokenizerFactory = new DefaultTokenizerFactory();
        private UnknownWordHandling unknownWordHandling = UnknownWordHandling.RemoveWord;
        private boolean useNormalizedWordVectors = true;
        /** Non-positive means "no cap": pad to the longest document in each batch. */
        private int maxSentenceLength = -1;
        private int minibatchSize = 32;
        private boolean sentencesAlongHeight = true;
        private DataSetPreProcessor dataSetPreProcessor;

        public Builder documentProvider(ConvMFDocumentProvider documentProvider) {
            this.documentProvider = documentProvider;
            return this;
        }

        public Builder wordVectors(WordVectors wordVectors) {
            this.wordVectors = wordVectors;
            return this;
        }

        public Builder tokenizerFactory(TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        public Builder unknownWordHandling(UnknownWordHandling unknownWordHandling) {
            this.unknownWordHandling = unknownWordHandling;
            return this;
        }

        public Builder minibatchSize(int minibatchSize) {
            this.minibatchSize = minibatchSize;
            return this;
        }

        public Builder useNormalizedWordVectors(boolean useNormalizedWordVectors) {
            this.useNormalizedWordVectors = useNormalizedWordVectors;
            return this;
        }

        public Builder maxSentenceLength(int maxSentenceLength) {
            this.maxSentenceLength = maxSentenceLength;
            return this;
        }

        public Builder sentencesAlongHeight(boolean sentencesAlongHeight) {
            this.sentencesAlongHeight = sentencesAlongHeight;
            return this;
        }

        public Builder dataSetPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
            this.dataSetPreProcessor = dataSetPreProcessor;
            return this;
        }

        /**
         * Creates the iterator.
         *
         * @throws IllegalStateException if no {@link WordVectors} were supplied
         */
        public ConvMFDocumentDataSetIterator build() {
            if (this.wordVectors == null) {
                throw new IllegalStateException("Cannot build ConvMFDocumentDataSetIterator without a WordVectors instance");
            }
            return new ConvMFDocumentDataSetIterator(this);
        }
    }
}

