/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.randomcutforest;

import com.amazon.randomcutforest.AbstractForestTraversalExecutor;
import com.amazon.randomcutforest.MultiVisitor;
import com.amazon.randomcutforest.TreeUpdater;
import com.amazon.randomcutforest.Visitor;
import com.amazon.randomcutforest.returntypes.ConvergingAccumulator;
import com.amazon.randomcutforest.tree.RandomCutTree;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Collectors;

/**
 * Traversal executor that runs per-tree work on a dedicated {@link ForkJoinPool}.
 *
 * <p>Every public operation wraps a parallel stream over the inherited
 * {@code treeUpdaters} list in a task that is submitted to this executor's own
 * pool and then joined, so the calling thread blocks until the result is ready.
 */
public class ParallelForestTraversalExecutor
extends AbstractForestTraversalExecutor {
    /** Pool that executes the traversal tasks; re-created on demand if it is ever {@code null}. */
    private ForkJoinPool forkJoinPool;

    /** Parallelism of the pool; also the batch size used by the converging traversal. */
    private final int threadPoolSize;

    /**
     * Creates an executor backed by a new {@link ForkJoinPool}.
     *
     * @param treeUpdaters   the updaters (one per tree) handed to the superclass
     * @param threadPoolSize parallelism passed to the {@link ForkJoinPool} constructor
     */
    public ParallelForestTraversalExecutor(ArrayList<TreeUpdater> treeUpdaters, int threadPoolSize) {
        super(treeUpdaters);
        this.threadPoolSize = threadPoolSize;
        this.forkJoinPool = new ForkJoinPool(threadPoolSize);
    }

    /**
     * Applies the point to every tree updater in parallel and waits for all of them.
     *
     * @param pointCopy     the point passed to each updater
     * @param sequenceIndex the sequence index passed to each updater
     */
    @Override
    protected void update(double[] pointCopy, long sequenceIndex) {
        submitAndJoin(() -> {
            treeUpdaters.parallelStream().forEach(updater -> updater.update(pointCopy, sequenceIndex));
            // Callable must return a value; there is no result for an update.
            return null;
        });
    }

    /**
     * Traverses every tree in parallel, reduces the per-tree results with
     * {@code accumulator} and applies {@code finisher} to the reduced value.
     *
     * @param point          the point passed to each tree traversal
     * @param visitorFactory creates the visitor used for a given tree
     * @param accumulator    combines two per-tree results into one
     * @param finisher       converts the combined result into the returned value
     * @return the finished result
     * @throws IllegalStateException if the reduction (after applying the finisher)
     *                               yields an empty {@link java.util.Optional}
     */
    @Override
    public <R, S> S traverseForest(double[] point, Function<RandomCutTree, Visitor<R>> visitorFactory, BinaryOperator<R> accumulator, Function<R, S> finisher) {
        return submitAndJoin(() -> treeUpdaters.parallelStream()
                .map(TreeUpdater::getTree)
                .map(tree -> {
                    Visitor<R> visitor = visitorFactory.apply(tree);
                    return tree.traverseTree(point, visitor);
                })
                .reduce(accumulator)
                .map(finisher))
                .orElseThrow(() -> new IllegalStateException("accumulator returned an empty result"));
    }

    /**
     * Traverses every tree in parallel and gathers the per-tree results with
     * the supplied collector.
     *
     * @param point          the point passed to each tree traversal
     * @param visitorFactory creates the visitor used for a given tree
     * @param collector      collects the per-tree results into the returned value
     * @return the collector's result
     */
    @Override
    public <R, S> S traverseForest(double[] point, Function<RandomCutTree, Visitor<R>> visitorFactory, Collector<R, ?, S> collector) {
        return submitAndJoin(() -> treeUpdaters.parallelStream()
                .map(TreeUpdater::getTree)
                .map(tree -> {
                    Visitor<R> visitor = visitorFactory.apply(tree);
                    return tree.traverseTree(point, visitor);
                })
                .collect(collector));
    }

    /**
     * Traverses the trees in batches of {@code threadPoolSize}, feeding each
     * batch's results to {@code accumulator} and stopping early once the
     * accumulator reports convergence.
     *
     * <p>Convergence is checked only after a whole batch has been accepted, so
     * every result of the batch in progress is always accumulated.
     *
     * @param point          the point passed to each tree traversal
     * @param visitorFactory creates the visitor used for a given tree
     * @param accumulator    receives per-tree results and decides when to stop
     * @param finisher       converts the accumulated value into the returned value
     * @return the finisher applied to the accumulator's accumulated value
     */
    @Override
    public <R, S> S traverseForest(double[] point, Function<RandomCutTree, Visitor<R>> visitorFactory, ConvergingAccumulator<R> accumulator, Function<R, S> finisher) {
        int treeCount = treeUpdaters.size();
        for (int i = 0; i < treeCount; i += threadPoolSize) {
            // Effectively-final copies so the lambda below can capture the batch bounds.
            int start = i;
            int end = Math.min(start + threadPoolSize, treeCount);
            List<R> results = submitAndJoin(() -> treeUpdaters.subList(start, end).parallelStream()
                    .map(TreeUpdater::getTree)
                    .map(tree -> {
                        Visitor<R> visitor = visitorFactory.apply(tree);
                        return tree.traverseTree(point, visitor);
                    })
                    .collect(Collectors.toList()));
            results.forEach(accumulator::accept);
            if (accumulator.isConverged()) {
                break;
            }
        }
        return finisher.apply(accumulator.getAccumulatedValue());
    }

    /**
     * Multi-visitor variant of the reducing traversal: traverses every tree in
     * parallel with {@code traverseTreeMulti}, reduces the per-tree results
     * with {@code accumulator} and applies {@code finisher}.
     *
     * @param point          the point passed to each tree traversal
     * @param visitorFactory creates the multi-visitor used for a given tree
     * @param accumulator    combines two per-tree results into one
     * @param finisher       converts the combined result into the returned value
     * @return the finished result
     * @throws IllegalStateException if the reduction (after applying the finisher)
     *                               yields an empty {@link java.util.Optional}
     */
    @Override
    public <R, S> S traverseForestMulti(double[] point, Function<RandomCutTree, MultiVisitor<R>> visitorFactory, BinaryOperator<R> accumulator, Function<R, S> finisher) {
        return submitAndJoin(() -> treeUpdaters.parallelStream()
                .map(TreeUpdater::getTree)
                .map(tree -> {
                    MultiVisitor<R> visitor = visitorFactory.apply(tree);
                    return tree.traverseTreeMulti(point, visitor);
                })
                .reduce(accumulator)
                .map(finisher))
                .orElseThrow(() -> new IllegalStateException("accumulator returned an empty result"));
    }

    /**
     * Multi-visitor variant of the collecting traversal: traverses every tree
     * in parallel with {@code traverseTreeMulti} and gathers the per-tree
     * results with the supplied collector.
     *
     * @param point          the point passed to each tree traversal
     * @param visitorFactory creates the multi-visitor used for a given tree
     * @param collector      collects the per-tree results into the returned value
     * @return the collector's result
     */
    @Override
    public <R, S> S traverseForestMulti(double[] point, Function<RandomCutTree, MultiVisitor<R>> visitorFactory, Collector<R, ?, S> collector) {
        return submitAndJoin(() -> treeUpdaters.parallelStream()
                .map(TreeUpdater::getTree)
                .map(tree -> {
                    MultiVisitor<R> visitor = visitorFactory.apply(tree);
                    return tree.traverseTreeMulti(point, visitor);
                })
                .collect(collector));
    }

    /**
     * Submits the task to this executor's pool and blocks until it completes.
     *
     * <p>The pool is created here if the field is {@code null} at call time.
     *
     * @param callable the task to run inside the pool
     * @param <T>      the task's result type
     * @return the value produced by {@code callable}
     */
    private <T> T submitAndJoin(Callable<T> callable) {
        if (forkJoinPool == null) {
            forkJoinPool = new ForkJoinPool(threadPoolSize);
        }
        return forkJoinPool.submit(callable).join();
    }
}

