/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.randomcutforest;

import com.amazon.randomcutforest.AbstractForestTraversalExecutor;
import com.amazon.randomcutforest.CommonUtils;
import com.amazon.randomcutforest.MultiVisitor;
import com.amazon.randomcutforest.ParallelForestTraversalExecutor;
import com.amazon.randomcutforest.SequentialForestTraversalExecutor;
import com.amazon.randomcutforest.TreeUpdater;
import com.amazon.randomcutforest.Visitor;
import com.amazon.randomcutforest.anomalydetection.AnomalyAttributionVisitor;
import com.amazon.randomcutforest.anomalydetection.AnomalyScoreVisitor;
import com.amazon.randomcutforest.imputation.ImputeVisitor;
import com.amazon.randomcutforest.inspect.NearNeighborVisitor;
import com.amazon.randomcutforest.interpolation.SimpleInterpolationVisitor;
import com.amazon.randomcutforest.returntypes.ConvergingAccumulator;
import com.amazon.randomcutforest.returntypes.DensityOutput;
import com.amazon.randomcutforest.returntypes.DiVector;
import com.amazon.randomcutforest.returntypes.InterpolationMeasure;
import com.amazon.randomcutforest.returntypes.Neighbor;
import com.amazon.randomcutforest.returntypes.OneSidedConvergingDiVectorAccumulator;
import com.amazon.randomcutforest.returntypes.OneSidedConvergingDoubleAccumulator;
import com.amazon.randomcutforest.sampler.SimpleStreamSampler;
import com.amazon.randomcutforest.tree.RandomCutTree;
import com.amazon.randomcutforest.util.ShingleBuilder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.stream.Collector;

/**
 * A forest of {@link RandomCutTree}s, each paired with its own {@link SimpleStreamSampler}, that is
 * updated one point at a time and answers forest-wide queries: anomaly score and attribution,
 * simple density, imputation of missing values, extrapolation and near-neighbor lookup.
 *
 * <p>Instances are created through {@link #builder()} or the {@code defaultForest} factories. Tree
 * traversal is delegated to an {@link AbstractForestTraversalExecutor}: a
 * {@link ParallelForestTraversalExecutor} when parallel execution is enabled, otherwise a
 * {@link SequentialForestTraversalExecutor}.
 */
public class RandomCutForest {
    /** Default sample size handed to each tree's sampler. */
    public static final int DEFAULT_SAMPLE_SIZE = 256;
    /** Default {@code outputAfter}, expressed as a fraction of the sample size. */
    public static final double DEFAULT_OUTPUT_AFTER_FRACTION = 0.25;
    /** The default {@code lambda} is {@code 1 / (DEFAULT_SAMPLE_SIZE_COEFFICIENT_IN_LAMBDA * sampleSize)}. */
    public static final double DEFAULT_SAMPLE_SIZE_COEFFICIENT_IN_LAMBDA = 10.0;
    /** Default number of trees in the forest. */
    public static final int DEFAULT_NUMBER_OF_TREES = 50;
    /** Default for {@link Builder#storeSequenceIndexesEnabled(boolean)}. */
    public static final boolean DEFAULT_STORE_SEQUENCE_INDEXES_ENABLED = false;
    /** Default for {@link Builder#centerOfMassEnabled(boolean)}. */
    public static final boolean DEFAULT_CENTER_OF_MASS_ENABLED = false;
    /** Default for {@link Builder#parallelExecutionEnabled(boolean)}. */
    public static final boolean DEFAULT_PARALLEL_EXECUTION_ENABLED = false;
    /** "High is critical" flag passed to the converging accumulators used by the approximate queries. */
    public static final boolean DEFAULT_APPROXIMATE_ANOMALY_SCORE_HIGH_IS_CRITICAL = true;
    /** Precision passed to the converging accumulators used by the approximate queries. */
    public static final double DEFAULT_APPROXIMATE_DYNAMIC_SCORE_PRECISION = 0.1;
    /** Minimum number of accepted values passed to the converging accumulators used by the approximate queries. */
    public static final int DEFAULT_APPROXIMATE_DYNAMIC_SCORE_MIN_VALUES_ACCEPTED = 5;

    /** Source of the per-sampler and per-tree seeds; seeded from the builder when a seed is given. */
    protected final Random rng;
    /** Required length of every point passed to this forest. */
    protected final int dimensions;
    /** Sample size handed to each tree's sampler. */
    protected final int sampleSize;
    /** Number of updates required before {@link #isOutputReady()} becomes true. */
    protected final int outputAfter;
    /** Number of trees (and samplers) in the forest. */
    protected final int numberOfTrees;
    /** The {@code lambda} handed to each tree's sampler. */
    protected final double lambda;
    protected final boolean storeSequenceIndexesEnabled;
    protected final boolean centerOfMassEnabled;
    protected final boolean parallelExecutionEnabled;
    /** Size of the traversal thread pool; 0 when parallel execution is disabled. */
    protected final int threadPoolSize;
    /** Executor that performs all updates and traversals over the trees. */
    protected final AbstractForestTraversalExecutor executor;

    /**
     * Validates the builder's settings, creates one sampler/tree pair per tree and selects the
     * traversal executor.
     *
     * @param builder the configuration to build from
     */
    protected RandomCutForest(Builder<?> builder) {
        CommonUtils.checkArgument(builder.numberOfTrees > 0, "numberOfTrees must be greater than 0");
        CommonUtils.checkArgument(builder.sampleSize > 0, "sampleSize must be greater than 0");
        builder.outputAfter.ifPresent(n -> {
            CommonUtils.checkArgument(n > 0, "outputAfter must be greater than 0");
            CommonUtils.checkArgument(n <= builder.sampleSize, "outputAfter must be smaller or equal to sampleSize");
        });
        CommonUtils.checkArgument(builder.dimensions > 0, "dimensions must be greater than 0");
        builder.lambda.ifPresent(value -> CommonUtils.checkArgument(value >= 0.0, "lambda must be greater than or equal to 0"));
        builder.threadPoolSize.ifPresent(n -> CommonUtils.checkArgument(n > 0, "threadPoolSize must be greater than 0. To disable thread pool, set parallel execution to 'false'."));

        this.numberOfTrees = builder.numberOfTrees;
        this.sampleSize = builder.sampleSize;
        // Clamp to 1: for sampleSize < 4 the fractional default truncates to 0, which would violate
        // the outputAfter > 0 rule enforced above for explicitly configured values.
        this.outputAfter = builder.outputAfter.orElse(Math.max(1, (int) (this.sampleSize * DEFAULT_OUTPUT_AFTER_FRACTION)));
        this.dimensions = builder.dimensions;
        this.lambda = builder.lambda.orElse(1.0 / (DEFAULT_SAMPLE_SIZE_COEFFICIENT_IN_LAMBDA * this.sampleSize));
        this.storeSequenceIndexesEnabled = builder.storeSequenceIndexesEnabled;
        this.centerOfMassEnabled = builder.centerOfMassEnabled;
        this.parallelExecutionEnabled = builder.parallelExecutionEnabled;
        this.rng = builder.randomSeed.map(Random::new).orElseGet(Random::new);

        // Kept as ArrayList: that is the concrete type the executors are constructed with here.
        ArrayList<TreeUpdater> treeUpdaters = new ArrayList<>(this.numberOfTrees);
        for (int i = 0; i < this.numberOfTrees; i++) {
            // The sampler seed is drawn before the tree seed; keep this order so that seeded
            // forests remain reproducible.
            SimpleStreamSampler sampler = new SimpleStreamSampler(this.sampleSize, this.lambda, this.rng.nextLong());
            RandomCutTree tree = RandomCutTree.builder()
                    .storeSequenceIndexesEnabled(this.storeSequenceIndexesEnabled)
                    .centerOfMassEnabled(this.centerOfMassEnabled)
                    .randomSeed(this.rng.nextLong())
                    .build();
            treeUpdaters.add(new TreeUpdater(sampler, tree));
        }

        if (this.parallelExecutionEnabled) {
            // availableProcessors() - 1 is 0 on a single-CPU host, which is not a valid pool size
            // (an explicit value of 0 is rejected above), so never default below 1.
            this.threadPoolSize = builder.threadPoolSize.orElse(Math.max(1, Runtime.getRuntime().availableProcessors() - 1));
            this.executor = new ParallelForestTraversalExecutor(treeUpdaters, this.threadPoolSize);
        } else {
            this.threadPoolSize = 0;
            this.executor = new SequentialForestTraversalExecutor(treeUpdaters);
        }
    }

    /** Returns a new builder with every setting at its default; {@code dimensions} must be set. */
    public static Builder<?> builder() {
        return new Builder<>();
    }

    /**
     * Creates a forest with default settings and a fixed random seed.
     *
     * @param dimensions the number of dimensions of the input points
     * @param randomSeed the seed used for all samplers and trees
     * @return a new forest
     */
    public static RandomCutForest defaultForest(int dimensions, long randomSeed) {
        return builder().dimensions(dimensions).randomSeed(randomSeed).build();
    }

    /**
     * Creates a forest with default settings and an unseeded random source.
     *
     * @param dimensions the number of dimensions of the input points
     * @return a new forest
     */
    public static RandomCutForest defaultForest(int dimensions) {
        return builder().dimensions(dimensions).build();
    }

    public int getNumberOfTrees() {
        return this.numberOfTrees;
    }

    public int getSampleSize() {
        return this.sampleSize;
    }

    public int getOutputAfter() {
        return this.outputAfter;
    }

    public int getDimensions() {
        return this.dimensions;
    }

    public double getLambda() {
        return this.lambda;
    }

    public boolean storeSequenceIndexesEnabled() {
        return this.storeSequenceIndexesEnabled;
    }

    public boolean centerOfMassEnabled() {
        return this.centerOfMassEnabled;
    }

    public boolean parallelExecutionEnabled() {
        return this.parallelExecutionEnabled;
    }

    /** Returns the traversal thread-pool size, or 0 when parallel execution is disabled. */
    public int getThreadPoolSize() {
        return this.threadPoolSize;
    }

    /** Checks that {@code point} is non-null and has exactly {@link #getDimensions()} entries. */
    private void checkPoint(double[] point) {
        CommonUtils.checkNotNull(point, "point must not be null");
        CommonUtils.checkArgument(point.length == this.dimensions, String.format("point.length must equal %d", this.dimensions));
    }

    /**
     * Feeds a point to every tree through the executor.
     *
     * @param point a non-null point with {@link #getDimensions()} entries
     */
    public void update(double[] point) {
        this.checkPoint(point);
        this.executor.update(point);
    }

    /**
     * Visits every tree with a freshly created visitor, reduces the per-tree results with
     * {@code accumulator} and applies {@code finisher} to the reduced value.
     */
    public <R, S> S traverseForest(double[] point, Function<RandomCutTree, Visitor<R>> visitorFactory, BinaryOperator<R> accumulator, Function<R, S> finisher) {
        this.checkPoint(point);
        CommonUtils.checkNotNull(visitorFactory, "visitorFactory must not be null");
        CommonUtils.checkNotNull(accumulator, "accumulator must not be null");
        CommonUtils.checkNotNull(finisher, "finisher must not be null");
        return this.executor.traverseForest(point, visitorFactory, accumulator, finisher);
    }

    /** Visits every tree with a freshly created visitor and combines the per-tree results with {@code collector}. */
    public <R, S> S traverseForest(double[] point, Function<RandomCutTree, Visitor<R>> visitorFactory, Collector<R, ?, S> collector) {
        this.checkPoint(point);
        CommonUtils.checkNotNull(visitorFactory, "visitorFactory must not be null");
        CommonUtils.checkNotNull(collector, "collector must not be null");
        return this.executor.traverseForest(point, visitorFactory, collector);
    }

    /**
     * Visits trees with freshly created visitors, feeding results into a {@link ConvergingAccumulator},
     * then applies {@code finisher} to the accumulated value.
     */
    public <R, S> S traverseForest(double[] point, Function<RandomCutTree, Visitor<R>> visitorFactory, ConvergingAccumulator<R> accumulator, Function<R, S> finisher) {
        this.checkPoint(point);
        CommonUtils.checkNotNull(visitorFactory, "visitorFactory must not be null");
        CommonUtils.checkNotNull(accumulator, "accumulator must not be null");
        CommonUtils.checkNotNull(finisher, "finisher must not be null");
        return this.executor.traverseForest(point, visitorFactory, accumulator, finisher);
    }

    /** Like the {@link BinaryOperator} {@code traverseForest} overload, but with {@link MultiVisitor}s. */
    public <R, S> S traverseForestMulti(double[] point, Function<RandomCutTree, MultiVisitor<R>> visitorFactory, BinaryOperator<R> accumulator, Function<R, S> finisher) {
        this.checkPoint(point);
        CommonUtils.checkNotNull(visitorFactory, "visitorFactory must not be null");
        CommonUtils.checkNotNull(accumulator, "accumulator must not be null");
        CommonUtils.checkNotNull(finisher, "finisher must not be null");
        return this.executor.traverseForestMulti(point, visitorFactory, accumulator, finisher);
    }

    /** Like the {@link Collector} {@code traverseForest} overload, but with {@link MultiVisitor}s. */
    public <R, S> S traverseForestMulti(double[] point, Function<RandomCutTree, MultiVisitor<R>> visitorFactory, Collector<R, ?, S> collector) {
        this.checkPoint(point);
        CommonUtils.checkNotNull(visitorFactory, "visitorFactory must not be null");
        CommonUtils.checkNotNull(collector, "collector must not be null");
        return this.executor.traverseForestMulti(point, visitorFactory, collector);
    }

    /**
     * Returns the per-tree anomaly score of {@code point} averaged over all trees.
     *
     * @param point a point with {@link #getDimensions()} entries
     * @return the average score, or 0.0 while {@link #isOutputReady()} is false
     */
    public double getAnomalyScore(double[] point) {
        if (!this.isOutputReady()) {
            return 0.0;
        }
        Function<RandomCutTree, Visitor<Double>> visitorFactory = tree -> new AnomalyScoreVisitor(point, tree.getRoot().getMass());
        BinaryOperator<Double> accumulator = Double::sum;
        Function<Double, Double> finisher = sum -> sum / (double) this.numberOfTrees;
        return this.traverseForest(point, visitorFactory, accumulator, finisher);
    }

    /**
     * Approximate variant of {@link #getAnomalyScore(double[])}. Per-tree scores are fed to a
     * {@link OneSidedConvergingDoubleAccumulator}, and the total is divided by the number of values
     * that accumulator reports as accepted.
     *
     * @param point a point with {@link #getDimensions()} entries
     * @return the approximate score, or 0.0 while {@link #isOutputReady()} is false
     */
    public double getApproximateAnomalyScore(double[] point) {
        if (!this.isOutputReady()) {
            return 0.0;
        }
        Function<RandomCutTree, Visitor<Double>> visitorFactory = tree -> new AnomalyScoreVisitor(point, tree.getRoot().getMass());
        OneSidedConvergingDoubleAccumulator accumulator = new OneSidedConvergingDoubleAccumulator(
                DEFAULT_APPROXIMATE_ANOMALY_SCORE_HIGH_IS_CRITICAL,
                DEFAULT_APPROXIMATE_DYNAMIC_SCORE_PRECISION,
                DEFAULT_APPROXIMATE_DYNAMIC_SCORE_MIN_VALUES_ACCEPTED,
                this.numberOfTrees);
        Function<Double, Double> finisher = x -> x / (double) accumulator.getValuesAccepted();
        return this.traverseForest(point, visitorFactory, accumulator, finisher);
    }

    /**
     * Returns the per-tree anomaly attribution of {@code point} averaged over all trees.
     *
     * @param point a point with {@link #getDimensions()} entries
     * @return the averaged attribution, or a fresh {@code DiVector(dimensions)} while
     *     {@link #isOutputReady()} is false
     */
    public DiVector getAnomalyAttribution(double[] point) {
        if (!this.isOutputReady()) {
            return new DiVector(this.dimensions);
        }
        Function<RandomCutTree, Visitor<DiVector>> visitorFactory = tree -> new AnomalyAttributionVisitor(point, tree.getRoot().getMass());
        BinaryOperator<DiVector> accumulator = DiVector::addToLeft;
        Function<DiVector, DiVector> finisher = x -> x.scale(1.0 / (double) this.numberOfTrees);
        return this.traverseForest(point, visitorFactory, accumulator, finisher);
    }

    /**
     * Approximate variant of {@link #getAnomalyAttribution(double[])}. Per-tree attributions are fed
     * to a {@link OneSidedConvergingDiVectorAccumulator}, and the total is scaled by one over the
     * number of values that accumulator reports as accepted.
     *
     * @param point a point with {@link #getDimensions()} entries
     * @return the approximate attribution, or a fresh {@code DiVector(dimensions)} while
     *     {@link #isOutputReady()} is false
     */
    public DiVector getApproximateAnomalyAttribution(double[] point) {
        if (!this.isOutputReady()) {
            return new DiVector(this.dimensions);
        }
        Function<RandomCutTree, Visitor<DiVector>> visitorFactory = tree -> new AnomalyAttributionVisitor(point, tree.getRoot().getMass());
        OneSidedConvergingDiVectorAccumulator accumulator = new OneSidedConvergingDiVectorAccumulator(
                this.dimensions,
                DEFAULT_APPROXIMATE_ANOMALY_SCORE_HIGH_IS_CRITICAL,
                DEFAULT_APPROXIMATE_DYNAMIC_SCORE_PRECISION,
                DEFAULT_APPROXIMATE_DYNAMIC_SCORE_MIN_VALUES_ACCEPTED,
                this.numberOfTrees);
        Function<DiVector, DiVector> finisher = vector -> vector.scale(1.0 / (double) accumulator.getValuesAccepted());
        return this.traverseForest(point, visitorFactory, accumulator, finisher);
    }

    /**
     * Computes a simple density estimate around {@code point} from per-tree interpolation measures.
     *
     * @param point a point with {@link #getDimensions()} entries
     * @return the density output, or an empty {@code DensityOutput(dimensions, sampleSize)} until
     *     {@link #samplersFull()} is true
     */
    public DensityOutput getSimpleDensity(double[] point) {
        if (!this.samplersFull()) {
            return new DensityOutput(this.dimensions, this.sampleSize);
        }
        Function<RandomCutTree, Visitor<InterpolationMeasure>> visitorFactory = tree -> new SimpleInterpolationVisitor(point, this.sampleSize, 1.0, this.centerOfMassEnabled);
        Collector<InterpolationMeasure, ?, InterpolationMeasure> collector = InterpolationMeasure.collector(this.dimensions, this.sampleSize, this.numberOfTrees);
        return new DensityOutput(this.traverseForest(point, visitorFactory, collector));
    }

    /**
     * Imputes the entries of {@code point} listed in the first {@code numberOfMissingValues}
     * positions of {@code missingIndexes}.
     *
     * <p>With one missing value, each tree proposes a candidate and the median candidate value (index
     * {@code numberOfTrees / 2} after sorting) is written into a copy of {@code point}. With several
     * missing values, the per-tree candidate points are sorted by ascending
     * {@link #getAnomalyScore(double[])} and the one at index {@code numberOfTrees / 4} is returned.
     *
     * @param point the point with missing entries
     * @param numberOfMissingValues how many leading entries of {@code missingIndexes} to impute (>= 0)
     * @param missingIndexes the indexes to impute; may be null only when {@code numberOfMissingValues} is 0
     * @return a copy of {@code point} when nothing is missing, an all-zero array of length
     *     {@code dimensions} while {@link #isOutputReady()} is false, otherwise the imputed point
     */
    public double[] imputeMissingValues(double[] point, int numberOfMissingValues, int[] missingIndexes) {
        CommonUtils.checkArgument(numberOfMissingValues >= 0, "numberOfMissingValues must be greater than or equal to 0");
        CommonUtils.checkNotNull(point, "point must not be null");
        if (numberOfMissingValues == 0) {
            return Arrays.copyOf(point, point.length);
        }
        CommonUtils.checkNotNull(missingIndexes, "missingIndexes must not be null");
        CommonUtils.checkArgument(numberOfMissingValues <= missingIndexes.length, "numberOfMissingValues must be less than or equal to missingIndexes.length");
        if (!this.isOutputReady()) {
            return new double[this.dimensions];
        }
        Function<RandomCutTree, MultiVisitor<double[]>> visitorFactory = tree -> new ImputeVisitor(point, numberOfMissingValues, missingIndexes);

        if (numberOfMissingValues == 1) {
            // Collect the single imputed coordinate from every candidate and take the median.
            Collector<double[], ArrayList<Double>, ArrayList<Double>> collector = Collector.of(
                    ArrayList::new,
                    (values, candidate) -> values.add(candidate[missingIndexes[0]]),
                    (left, right) -> {
                        left.addAll(right);
                        return left;
                    },
                    values -> {
                        values.sort(Comparator.naturalOrder());
                        return values;
                    });
            ArrayList<Double> imputedValues = this.traverseForestMulti(point, visitorFactory, collector);
            double[] returnPoint = Arrays.copyOf(point, this.dimensions);
            returnPoint[missingIndexes[0]] = imputedValues.get(this.numberOfTrees / 2);
            return returnPoint;
        }

        // Several missing values: rank whole candidate points by how anomalous the forest finds them.
        Collector<double[], ArrayList<double[]>, ArrayList<double[]>> collector = Collector.of(
                ArrayList::new,
                ArrayList::add,
                (left, right) -> {
                    left.addAll(right);
                    return left;
                },
                candidates -> {
                    candidates.sort(Comparator.comparingDouble(this::getAnomalyScore));
                    return candidates;
                });
        ArrayList<double[]> imputedPoints = this.traverseForestMulti(point, visitorFactory, collector);
        return imputedPoints.get(this.numberOfTrees / 4);
    }

    /**
     * Forecasts {@code horizon} future blocks of {@code blockSize} values by repeated imputation.
     *
     * @param point the current shingled point; must have exactly {@link #getDimensions()} entries
     * @param horizon the number of blocks to forecast (>= 0)
     * @param blockSize values per block; must satisfy {@code 0 < blockSize < dimensions} and divide {@code dimensions}
     * @param cyclic whether the shingle is cyclic (blocks overwritten in place) or sliding (shifted left)
     * @param shingleIndex for cyclic shingles, the block position of the next value to write
     * @return an array of {@code blockSize * horizon} forecast values, block by block
     */
    public double[] extrapolateBasic(double[] point, int horizon, int blockSize, boolean cyclic, int shingleIndex) {
        this.checkPoint(point);
        CommonUtils.checkArgument(horizon >= 0, "horizon must be greater than or equal to 0");
        CommonUtils.checkArgument(0 < blockSize && blockSize < this.dimensions, "blockSize must be between 0 and dimensions (exclusive)");
        CommonUtils.checkArgument(this.dimensions % blockSize == 0, "dimensions must be evenly divisible by blockSize");
        CommonUtils.checkArgument(0 <= shingleIndex && shingleIndex < this.dimensions / blockSize, "shingleIndex must be between 0 (inclusive) and dimensions / blockSize");
        double[] result = new double[blockSize * horizon];
        int[] missingIndexes = new int[blockSize];
        // Work on a copy so the caller's point is never modified.
        double[] queryPoint = Arrays.copyOf(point, this.dimensions);
        if (cyclic) {
            this.extrapolateBasicCyclic(result, horizon, blockSize, shingleIndex, queryPoint, missingIndexes);
        } else {
            this.extrapolateBasicSliding(result, horizon, blockSize, queryPoint, missingIndexes);
        }
        return result;
    }

    /** Same as the five-argument overload with {@code shingleIndex = 0}. */
    public double[] extrapolateBasic(double[] point, int horizon, int blockSize, boolean cyclic) {
        return this.extrapolateBasic(point, horizon, blockSize, cyclic, 0);
    }

    /** Forecasts using the shingle, block size, cyclic flag and shingle index held by {@code builder}. */
    public double[] extrapolateBasic(ShingleBuilder builder, int horizon) {
        return this.extrapolateBasic(builder.getShingle(), horizon, builder.getInputPointSize(), builder.isCyclic(), builder.getShingleIndex());
    }

    /**
     * Sliding extrapolation. On each step the shingle is shifted left by one block, the trailing
     * block is imputed, and that block is written both back into {@code queryPoint} and into
     * {@code result}. Mutates {@code result}, {@code queryPoint} and {@code missingIndexes}.
     */
    void extrapolateBasicSliding(double[] result, int horizon, int blockSize, double[] queryPoint, int[] missingIndexes) {
        int tailStart = this.dimensions - blockSize;
        Arrays.fill(missingIndexes, 0);
        for (int y = 0; y < blockSize; ++y) {
            missingIndexes[y] = tailStart + y;
        }
        int resultIndex = 0;
        for (int k = 0; k < horizon; ++k) {
            System.arraycopy(queryPoint, blockSize, queryPoint, 0, tailStart);
            double[] imputedPoint = this.imputeMissingValues(queryPoint, blockSize, missingIndexes);
            for (int y = 0; y < blockSize; ++y) {
                double value = imputedPoint[tailStart + y];
                queryPoint[tailStart + y] = value;
                result[resultIndex++] = value;
            }
        }
    }

    /**
     * Cyclic extrapolation. On each step the block starting at {@code currentPosition} (wrapping
     * modulo {@code dimensions}) is imputed in place, copied into {@code result}, and the position
     * advances by one block. Mutates {@code result}, {@code queryPoint} and {@code missingIndexes}.
     */
    void extrapolateBasicCyclic(double[] result, int horizon, int blockSize, int shingleIndex, double[] queryPoint, int[] missingIndexes) {
        int resultIndex = 0;
        int currentPosition = shingleIndex;
        Arrays.fill(missingIndexes, 0);
        for (int k = 0; k < horizon; ++k) {
            for (int y = 0; y < blockSize; ++y) {
                missingIndexes[y] = (currentPosition + y) % this.dimensions;
            }
            double[] imputedPoint = this.imputeMissingValues(queryPoint, blockSize, missingIndexes);
            for (int y = 0; y < blockSize; ++y) {
                int index = (currentPosition + y) % this.dimensions;
                double value = imputedPoint[index];
                queryPoint[index] = value;
                result[resultIndex++] = value;
            }
            currentPosition = (currentPosition + blockSize) % this.dimensions;
        }
    }

    /**
     * Returns the sampled points the trees' near-neighbor visitors find within
     * {@code distanceThreshold} of {@code point}, combined with {@link Neighbor#collector()}.
     *
     * @param point a point with {@link #getDimensions()} entries
     * @param distanceThreshold a positive distance bound
     * @return the neighbors found, or an empty list while {@link #isOutputReady()} is false
     */
    public List<Neighbor> getNearNeighborsInSample(double[] point, double distanceThreshold) {
        CommonUtils.checkNotNull(point, "point must not be null");
        CommonUtils.checkArgument(distanceThreshold > 0.0, "distanceThreshold must be greater than 0");
        if (!this.isOutputReady()) {
            return Collections.emptyList();
        }
        Function<RandomCutTree, Visitor<Optional<Neighbor>>> visitorFactory = tree -> new NearNeighborVisitor(point, distanceThreshold);
        return this.traverseForest(point, visitorFactory, Neighbor.collector());
    }

    /** Same as the two-argument overload with an unbounded distance threshold. */
    public List<Neighbor> getNearNeighborsInSample(double[] point) {
        return this.getNearNeighborsInSample(point, Double.POSITIVE_INFINITY);
    }

    /** Returns true once at least {@link #getOutputAfter()} updates have been applied. */
    public boolean isOutputReady() {
        return this.executor.getTotalUpdates() >= (long) this.outputAfter;
    }

    /** Returns true once at least {@link #getSampleSize()} updates have been applied. */
    public boolean samplersFull() {
        return this.executor.getTotalUpdates() >= (long) this.sampleSize;
    }

    /** Returns the total number of updates the executor has applied. */
    public long getTotalUpdates() {
        return this.executor.getTotalUpdates();
    }

    /**
     * Fluent builder for {@link RandomCutForest}. The self-referential type parameter lets
     * subclasses keep their own type through chained calls.
     *
     * @param <T> the concrete builder type returned by the setters
     */
    public static class Builder<T extends Builder<T>> {
        private int dimensions;
        private int sampleSize = DEFAULT_SAMPLE_SIZE;
        private Optional<Integer> outputAfter = Optional.empty();
        private int numberOfTrees = DEFAULT_NUMBER_OF_TREES;
        private Optional<Double> lambda = Optional.empty();
        private Optional<Long> randomSeed = Optional.empty();
        private boolean storeSequenceIndexesEnabled = DEFAULT_STORE_SEQUENCE_INDEXES_ENABLED;
        private boolean centerOfMassEnabled = DEFAULT_CENTER_OF_MASS_ENABLED;
        private boolean parallelExecutionEnabled = DEFAULT_PARALLEL_EXECUTION_ENABLED;
        private Optional<Integer> threadPoolSize = Optional.empty();

        /** Returns this builder as its concrete type {@code T} for fluent chaining. */
        @SuppressWarnings("unchecked")
        private T self() {
            // Correct whenever a subclass binds T to itself, as the recursive bound intends.
            return (T) this;
        }

        /** Sets the required length of every point (must be > 0). */
        public T dimensions(int dimensions) {
            this.dimensions = dimensions;
            return this.self();
        }

        /** Sets the sample size handed to each tree's sampler (must be > 0). */
        public T sampleSize(int sampleSize) {
            this.sampleSize = sampleSize;
            return this.self();
        }

        /** Sets the number of updates before output is ready (must be in 1..sampleSize). */
        public T outputAfter(int outputAfter) {
            this.outputAfter = Optional.of(outputAfter);
            return this.self();
        }

        /** Sets the number of trees (must be > 0). */
        public T numberOfTrees(int numberOfTrees) {
            this.numberOfTrees = numberOfTrees;
            return this.self();
        }

        /** Sets the {@code lambda} handed to each sampler (must be >= 0). */
        public T lambda(double lambda) {
            this.lambda = Optional.of(lambda);
            return this.self();
        }

        /** Sets the seed for the forest's random source, making sampler and tree seeds reproducible. */
        public T randomSeed(long randomSeed) {
            this.randomSeed = Optional.of(randomSeed);
            return this.self();
        }

        /** Sets the flag forwarded to each tree's {@code storeSequenceIndexesEnabled}. */
        public T storeSequenceIndexesEnabled(boolean storeSequenceIndexesEnabled) {
            this.storeSequenceIndexesEnabled = storeSequenceIndexesEnabled;
            return this.self();
        }

        /** Sets the flag forwarded to each tree's {@code centerOfMassEnabled}. */
        public T centerOfMassEnabled(boolean centerOfMassEnabled) {
            this.centerOfMassEnabled = centerOfMassEnabled;
            return this.self();
        }

        /** Chooses the parallel (thread-pool) executor instead of the sequential one. */
        public T parallelExecutionEnabled(boolean parallelExecutionEnabled) {
            this.parallelExecutionEnabled = parallelExecutionEnabled;
            return this.self();
        }

        /** Sets the thread-pool size (must be > 0); only used when parallel execution is enabled. */
        public T threadPoolSize(int threadPoolSize) {
            this.threadPoolSize = Optional.of(threadPoolSize);
            return this.self();
        }

        /** Validates the settings and builds the forest. */
        public RandomCutForest build() {
            return new RandomCutForest(this);
        }
    }
}

