/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.randomcutforest.tree;

import com.amazon.randomcutforest.CommonUtils;
import com.amazon.randomcutforest.MultiVisitor;
import com.amazon.randomcutforest.Visitor;
import com.amazon.randomcutforest.sampler.WeightedPoint;
import com.amazon.randomcutforest.tree.BoundingBox;
import com.amazon.randomcutforest.tree.Cut;
import com.amazon.randomcutforest.tree.Node;
import java.util.Arrays;
import java.util.Random;

/**
 * A random cut tree: a binary tree whose internal nodes split space with axis-aligned random cuts and whose leaves
 * hold points together with a multiplicity ("mass"). Points are inserted with {@link #addPoint(WeightedPoint)},
 * removed with {@link #deletePoint(WeightedPoint)}, and queried by walking the tree with
 * {@link #traverseTree(double[], Visitor)} or {@link #traverseTreeMulti(double[], MultiVisitor)}.
 *
 * <p>
 * Instances are created with {@link #builder()}. This class performs no internal synchronization.
 */
public class RandomCutTree {

    /** Default value for {@link Builder#storeSequenceIndexesEnabled(boolean)}. */
    public static final boolean DEFAULT_STORE_SEQUENCE_INDEXES_ENABLED = false;

    /** Default value for {@link Builder#centerOfMassEnabled(boolean)}. */
    public static final boolean DEFAULT_CENTER_OF_MASS_ENABLED = false;

    /** When true, leaves record the sequence index of every copy of their point. */
    private final boolean storeSequenceIndexesEnabled;

    /** When true, internal nodes maintain a running sum of the points beneath them. */
    private final boolean centerOfMassEnabled;

    /** Source of randomness used to choose cuts. */
    private final Random random;

    /** Root of the tree, or {@code null} when the tree is empty. */
    protected Node root;

    /**
     * Creates a tree from builder settings. If the builder supplies no {@link Random}, a new unseeded one is used.
     *
     * @param builder the configured builder
     */
    protected RandomCutTree(Builder<?> builder) {
        this.storeSequenceIndexesEnabled = builder.storeSequenceIndexesEnabled;
        this.centerOfMassEnabled = builder.centerOfMassEnabled;
        this.random = builder.random != null ? builder.random : new Random();
    }

    /**
     * @return a new builder initialized with the default settings
     */
    public static Builder<?> builder() {
        return new Builder<>();
    }

    /**
     * Creates a tree with default settings and a seeded random number generator.
     *
     * @param randomSeed seed for the tree's {@link Random}
     * @return a new, empty tree
     */
    public static RandomCutTree defaultTree(long randomSeed) {
        return builder().randomSeed(randomSeed).build();
    }

    /**
     * Creates a tree with default settings and an unseeded random number generator.
     *
     * @return a new, empty tree
     */
    public static RandomCutTree defaultTree() {
        return builder().build();
    }

    /** @return whether internal nodes maintain point sums */
    public boolean centerOfMassEnabled() {
        return this.centerOfMassEnabled;
    }

    /** @return whether leaves store the sequence indexes of their points */
    public boolean storeSequenceIndexesEnabled() {
        return this.storeSequenceIndexesEnabled;
    }

    /**
     * Draws a random axis-aligned cut of {@code box}. A dimension is chosen with probability proportional to its
     * range, and the cut value lies in the half-open interval {@code [min, max)} of that dimension. Both sides of the
     * cut are therefore non-empty.
     *
     * @param random source of randomness
     * @param box    the box to cut; its range sum must be positive
     * @return the chosen cut
     * @throws IllegalStateException if no dimension with a positive range exists
     */
    static Cut randomCut(Random random, BoundingBox box) {
        double rangeSum = box.getRangeSum();
        CommonUtils.checkArgument(rangeSum > 0.0, "box.getRangeSum() must be greater than 0");

        // Uniform over [0, rangeSum); walking the per-dimension ranges maps it to a dimension and an offset.
        double breakPoint = random.nextDouble() * rangeSum;

        for (int i = 0; i < box.getDimensions(); ++i) {
            double range = box.getRange(i);
            // A zero-range dimension can never be cut, because one side would be empty. Without this guard,
            // nextDouble() == 0.0 would select a leading zero-range dimension.
            if (range > 0.0 && breakPoint <= range) {
                double minValue = box.getMinValue(i);
                double maxValue = box.getMaxValue(i);
                double cutValue = minValue + breakPoint;
                // Keep the cut inside [min, max). The sum can land on max, or just past it after rounding.
                if (cutValue >= maxValue && minValue < maxValue) {
                    cutValue = Math.nextAfter(maxValue, minValue);
                }
                return new Cut(i, cutValue);
            }
            breakPoint -= range;
        }

        // Floating-point rounding can leave a tiny positive residual after the last dimension. The break point then
        // belongs at the very top of the last dimension that has a positive range.
        for (int i = box.getDimensions() - 1; i >= 0; --i) {
            if (box.getRange(i) > 0.0) {
                return new Cut(i, Math.nextAfter(box.getMaxValue(i), box.getMinValue(i)));
            }
        }
        throw new IllegalStateException("The break point did not lie inside the expected range");
    }

    /**
     * Puts {@code newNode} in the position of {@code oldNode} under {@code oldNode}'s parent, and sets
     * {@code newNode}'s parent accordingly. If {@code oldNode} has no parent, only {@code newNode}'s parent is set (to
     * {@code null}). {@code oldNode}'s own parent pointer is not modified.
     *
     * @param oldNode node being replaced
     * @param newNode node taking its place
     */
    static void replaceNode(Node oldNode, Node newNode) {
        Node parent = oldNode.getParent();
        if (parent != null) {
            if (parent.getLeftChild() == oldNode) {
                parent.setLeftChild(newNode);
            } else {
                parent.setRightChild(newNode);
            }
        }
        newNode.setParent(parent);
    }

    /**
     * Returns the other child of {@code node}'s parent.
     *
     * @param node a node with a parent
     * @return the sibling of {@code node}
     * @throws IllegalArgumentException if the parent has neither child equal to {@code node}
     */
    static Node getSibling(Node node) {
        Node parent = node.getParent();
        CommonUtils.checkNotNull(parent, "node parent must not be null");
        if (parent.getLeftChild() == node) {
            return parent.getRightChild();
        }
        if (parent.getRightChild() == node) {
            return parent.getLeftChild();
        }
        throw new IllegalArgumentException("node parent does not link back to node");
    }

    /**
     * Removes one copy of the given point from the tree.
     *
     * @param weightedPoint point (and sequence index) to remove
     * @throws IllegalStateException if the tree's leaf at the point's position does not match it
     */
    public void deletePoint(WeightedPoint weightedPoint) {
        CommonUtils.checkState(this.root != null, "root must not be null");
        this.deletePoint(this.root, weightedPoint.getPoint(), weightedPoint.getSequenceIndex());
    }

    /**
     * Recursively removes one copy of {@code point} from the subtree rooted at {@code node}.
     *
     * <p>
     * At the leaf, a mass above 1 is decremented. Otherwise the leaf and its parent are removed and the sibling takes
     * the parent's place. On the way back up, each ancestor recomputes its bounding box from its children, decrements
     * its mass, and (if enabled) subtracts the point from its point sum.
     */
    void deletePoint(Node node, double[] point, long sequenceIndex) {
        if (node.isLeaf()) {
            if (!node.leafPointEquals(point)) {
                throw new IllegalStateException(Arrays.toString(point) + " " + Arrays.toString(node.getLeafPoint())
                        + " " + Arrays.equals(node.getLeafPoint(), point)
                        + " Inconsistency in trees in delete step.");
            }
            if (this.storeSequenceIndexesEnabled && !node.getSequenceIndexes().contains(sequenceIndex)) {
                throw new IllegalStateException("Error in sequence index. Inconsistency in trees in delete step.");
            }
            if (node.getMass() > 1) {
                node.decrementMass();
                if (this.storeSequenceIndexesEnabled) {
                    node.deleteSequenceIndex(sequenceIndex);
                }
                return;
            }
            // Last copy of this point: drop the leaf together with its parent and promote the sibling.
            Node parent = node.getParent();
            if (parent == null) {
                this.root = null;
                return;
            }
            Node grandParent = parent.getParent();
            if (grandParent == null) {
                this.root = RandomCutTree.getSibling(node);
                this.root.setParent(null);
            } else {
                Node sibling = RandomCutTree.getSibling(node);
                RandomCutTree.replaceNode(parent, sibling);
            }
            return;
        }
        if (Node.isLeftOf(point, node)) {
            this.deletePoint(node.getLeftChild(), point, sequenceIndex);
        } else {
            this.deletePoint(node.getRightChild(), point, sequenceIndex);
        }
        // If the child removal spliced this node out of the tree, the updates below touch a detached node and are
        // harmless: it still references both of its former children.
        BoundingBox leftBox = node.getLeftChild().getBoundingBox();
        BoundingBox rightBox = node.getRightChild().getBoundingBox();
        node.setBoundingBox(leftBox.getMergedBox(rightBox));
        node.decrementMass();
        if (this.centerOfMassEnabled) {
            node.subtractFromPointSum(point);
        }
    }

    /**
     * Inserts one copy of the given point into the tree.
     *
     * @param weightedPoint point (and sequence index) to add
     */
    public void addPoint(WeightedPoint weightedPoint) {
        if (this.root == null) {
            this.root = this.newLeafNode(weightedPoint.getPoint(), weightedPoint.getSequenceIndex());
        } else {
            this.addPoint(this.root, weightedPoint.getPoint(), weightedPoint.getSequenceIndex());
        }
    }

    /**
     * Recursively inserts {@code point} into the subtree rooted at {@code node}.
     *
     * <p>
     * A leaf holding an equal point absorbs it by incrementing its mass. Otherwise, if {@code node}'s box does not
     * contain the point, a random cut of the merged box is drawn. If that cut separates the existing box from the
     * point, a new internal node is inserted above {@code node` with a new leaf for the point. In every other case the
     * method descends and then updates this node's box, mass and (if enabled) point sum.
     */
    private void addPoint(Node node, double[] point, long sequenceIndex) {
        if (node.isLeaf() && node.leafPointEquals(point)) {
            node.incrementMass();
            if (this.storeSequenceIndexesEnabled) {
                node.addSequenceIndex(sequenceIndex);
            }
            return;
        }
        BoundingBox existingBox = node.getBoundingBox();
        BoundingBox mergedBox = existingBox.getMergedBox(point);
        if (!existingBox.contains(point)) {
            Cut cut = RandomCutTree.randomCut(this.random, mergedBox);
            int splitDimension = cut.getDimension();
            double splitValue = cut.getValue();
            double minValue = existingBox.getMinValue(splitDimension);
            double maxValue = existingBox.getMaxValue(splitDimension);
            // The cut separates the point from the whole existing box: split here instead of descending.
            if (minValue > splitValue || maxValue <= splitValue) {
                Node leaf = this.newLeafNode(point, sequenceIndex);
                Node mergedNode = minValue > splitValue
                        ? this.newNode(leaf, node, cut, mergedBox)
                        : this.newNode(node, leaf, cut, mergedBox);
                if (node.getParent() == null) {
                    this.root = mergedNode;
                } else {
                    RandomCutTree.replaceNode(node, mergedNode);
                }
                leaf.setParent(mergedNode);
                node.setParent(mergedNode);
                return;
            }
        }
        if (Node.isLeftOf(point, node)) {
            this.addPoint(node.getLeftChild(), point, sequenceIndex);
        } else {
            this.addPoint(node.getRightChild(), point, sequenceIndex);
        }
        node.setBoundingBox(mergedBox);
        node.incrementMass();
        if (this.centerOfMassEnabled) {
            node.addToPointSum(point);
        }
    }

    /**
     * Follows the single root-to-leaf path selected by {@code point}. The leaf is visited first, then each ancestor
     * back up to the root.
     *
     * @param point   query point
     * @param visitor visitor accumulating a result
     * @param <R>     result type
     * @return {@code visitor.getResult()} after the traversal
     */
    public <R> R traverseTree(double[] point, Visitor<R> visitor) {
        CommonUtils.checkNotNull(point, "point must not be null");
        CommonUtils.checkNotNull(visitor, "visitor must not be null");
        CommonUtils.checkState(this.root != null, "this tree doesn't contain any nodes");
        this.traversePathToLeafAndVisitNodes(point, visitor, this.root, 0);
        return visitor.getResult();
    }

    /** Recursive helper for {@link #traverseTree}: descends toward {@code point}, visiting nodes on the way back. */
    private <R> void traversePathToLeafAndVisitNodes(double[] point, Visitor<R> visitor, Node currentNode,
            int depthOfNode) {
        if (currentNode.isLeaf()) {
            visitor.acceptLeaf(currentNode, depthOfNode);
        } else {
            Node childNode = Node.isLeftOf(point, currentNode) ? currentNode.getLeftChild()
                    : currentNode.getRightChild();
            this.traversePathToLeafAndVisitNodes(point, visitor, childNode, depthOfNode + 1);
            visitor.accept(currentNode, depthOfNode);
        }
    }

    /**
     * Like {@link #traverseTree}, but at any internal node where {@code visitor.trigger(node)} is true, both subtrees
     * are explored. The right subtree uses a copy of the visitor, and the copy is merged back with
     * {@code combine}.
     *
     * @param point   query point
     * @param visitor visitor accumulating a result
     * @param <R>     result type
     * @return {@code visitor.getResult()} after the traversal
     */
    public <R> R traverseTreeMulti(double[] point, MultiVisitor<R> visitor) {
        CommonUtils.checkNotNull(point, "point must not be null");
        CommonUtils.checkNotNull(visitor, "visitor must not be null");
        CommonUtils.checkState(this.root != null, "this tree doesn't contain any nodes");
        this.traverseTreeMulti(point, visitor, this.root, 0);
        return visitor.getResult();
    }

    /** Recursive helper for {@link #traverseTreeMulti(double[], MultiVisitor)}. */
    private <R> void traverseTreeMulti(double[] point, MultiVisitor<R> visitor, Node currentNode, int depthOfNode) {
        if (currentNode.isLeaf()) {
            visitor.acceptLeaf(currentNode, depthOfNode);
        } else if (visitor.trigger(currentNode)) {
            this.traverseTreeMulti(point, visitor, currentNode.getLeftChild(), depthOfNode + 1);
            MultiVisitor<R> newVisitor = visitor.newCopy();
            this.traverseTreeMulti(point, newVisitor, currentNode.getRightChild(), depthOfNode + 1);
            visitor.combine(newVisitor);
            visitor.accept(currentNode, depthOfNode);
        } else {
            Node childNode = Node.isLeftOf(point, currentNode) ? currentNode.getLeftChild()
                    : currentNode.getRightChild();
            this.traverseTreeMulti(point, visitor, childNode, depthOfNode + 1);
            visitor.accept(currentNode, depthOfNode);
        }
    }

    /** Creates a leaf of mass 1 for {@code point}, recording {@code sequenceIndex} if enabled. */
    private Node newLeafNode(double[] point, long sequenceIndex) {
        Node node = new Node(point);
        node.setMass(1);
        if (this.storeSequenceIndexesEnabled) {
            node.addSequenceIndex(sequenceIndex);
        }
        return node;
    }

    /**
     * Creates an internal node over the given children. Its mass (and, if enabled, point sum) is the total of the
     * non-null children's values.
     */
    private Node newNode(Node leftChild, Node rightChild, Cut cut, BoundingBox box) {
        Node node = new Node(leftChild, rightChild, cut, box, this.centerOfMassEnabled);
        if (leftChild != null) {
            node.addMass(leftChild.getMass());
            if (this.centerOfMassEnabled) {
                node.addToPointSum(leftChild.getPointSum());
            }
        }
        if (rightChild != null) {
            node.addMass(rightChild.getMass());
            if (this.centerOfMassEnabled) {
                node.addToPointSum(rightChild.getPointSum());
            }
        }
        return node;
    }

    /** @return the root node, or {@code null} if the tree is empty */
    public Node getRoot() {
        return this.root;
    }

    /**
     * Fluent builder for {@link RandomCutTree}. The self type {@code T} lets subclasses keep their own type across
     * chained calls.
     *
     * @param <T> concrete builder type
     */
    public static class Builder<T extends Builder<T>> {
        private boolean storeSequenceIndexesEnabled = DEFAULT_STORE_SEQUENCE_INDEXES_ENABLED;
        private boolean centerOfMassEnabled = DEFAULT_CENTER_OF_MASS_ENABLED;
        private Random random = null;

        /** Returns this builder typed as its concrete subclass, for fluent chaining. */
        @SuppressWarnings("unchecked") // safe by the F-bounded contract T extends Builder<T>
        private T self() {
            return (T) this;
        }

        /** Sets whether leaves store sequence indexes. */
        public T storeSequenceIndexesEnabled(boolean storeSequenceIndexesEnabled) {
            this.storeSequenceIndexesEnabled = storeSequenceIndexesEnabled;
            return self();
        }

        /** Sets whether internal nodes maintain point sums. */
        public T centerOfMassEnabled(boolean centerOfMassEnabled) {
            this.centerOfMassEnabled = centerOfMassEnabled;
            return self();
        }

        /** Uses the given generator for cuts; {@code null} means a new unseeded {@link Random}. */
        public T random(Random random) {
            this.random = random;
            return self();
        }

        /** Uses a new {@link Random} seeded with {@code randomSeed} for cuts. */
        public T randomSeed(long randomSeed) {
            this.random = new Random(randomSeed);
            return self();
        }

        /** @return a new, empty tree with this builder's settings */
        public RandomCutTree build() {
            return new RandomCutTree(this);
        }
    }
}

