package cs2110;

public class TreeReconstruction {

    /**
     * Reconstructs a binary tree from its pre-order and in-order traversals.
     * Requires that `preorder` and `inorder` are valid pre-order and in-order traversals
     * (respectively) of the same non-empty binary tree. An empty pair of traversals yields
     * `null` (the empty tree).
     *
     * <p>Recursion depth is proportional to the height of the tree, and the linear search for
     * each subtree root makes the worst case quadratic in the number of nodes.
     *
     * @param preorder the node values of the tree, listed in pre-order
     * @param inorder  the node values of the same tree, listed in in-order
     * @return the root of the reconstructed tree
     * @throws IllegalArgumentException if the two traversals have different lengths, or if some
     *     subtree root read from `preorder` does not occur in the matching segment of `inorder`
     *     (i.e. the arrays cannot be traversals of the same tree)
     */
    public static ImmutableBinaryTree<Integer> reconstructTree(int[] preorder, int[] inorder) {
        // Validate explicitly: the sanity checks inside `build` are `assert`s, which are
        // disabled unless the JVM runs with -ea.
        if (preorder.length != inorder.length) {
            throw new IllegalArgumentException("traversal lengths differ: preorder has "
                    + preorder.length + " elements, inorder has " + inorder.length);
        }
        int n = preorder.length;
        return build(preorder, inorder, 0, n, 0, n);
    }

    /**
     * Constructs the subtree described by pre-order traversal segment `preorder[preStart,preEnd)`
     * and in-order traversal segment `inorder[inStart,inEnd)` and returns a reference to its root
     * (or `null` if the segments are empty).
     *
     * @throws IllegalArgumentException if the subtree root `preorder[preStart]` does not occur in
     *     `inorder[inStart,inEnd)`
     */
    private static ImmutableBinaryTree<Integer> build(int[] preorder, int[] inorder, int preStart,
        int preEnd, int inStart, int inEnd) {
        assert preEnd - preStart == inEnd - inStart; // sanity check: both ranges should have same length
        if (preEnd - preStart == 0) { // empty range
            return null;
        }
        // A one-element range needs no special case: the general steps below produce a leaf,
        // and the bounded search in Step 2 verifies that both traversals agree on its value.

        // Step 1: Use pre-order traversal to identify subtree root
        int rootVal = preorder[preStart];

        // Step 2: Use in-order traversal to determine left subtree size. The search is confined
        // to this segment so that inconsistent input is reported instead of silently reading
        // values that belong to other subtrees.
        int rootIndex = indexOf(inorder, rootVal, inStart, inEnd);
        int leftSize = rootIndex - inStart;

        // Step 3: Recursively construct left subtree
        ImmutableBinaryTree<Integer> left = build(preorder, inorder, preStart + 1, preStart + 1 + leftSize,
                inStart, rootIndex);

        // Step 4: Recursively construct right subtree
        ImmutableBinaryTree<Integer> right = build(preorder, inorder, preStart + 1 + leftSize, preEnd,
                rootIndex + 1, inEnd);

        // Step 5: Construct subtree to return
        return new ImmutableBinaryTree<>(rootVal, left, right);
    }

    /**
     * Returns the smallest index `i` in `[start, end)` such that `a[i] == key`.
     *
     * @throws IllegalArgumentException if `key` does not occur in `a[start, end)`
     */
    private static int indexOf(int[] a, int key, int start, int end) {
        for (int i = start; i < end; i++) {
            if (a[i] == key) {
                return i;
            }
        }
        throw new IllegalArgumentException(
                "value " + key + " not found in in-order segment [" + start + ", " + end + ")");
    }
}
