Efficiently Implement a Basic Segment Tree

Photo by Kevin Ku on Unsplash

A segment tree is a binary tree structure that stores information about intervals, or segments. It is one of the most powerful tree data structures for problems involving range queries and modifications, for example querying the sum, maximum, or minimum of any given range after one or more values in the array have been modified. A segment tree supports both querying and modifying in O(log N) time, where N is the length of the input array.

Many articles explain the principles behind segment trees in detail. However, the code in many of them is hard to follow, especially for programming beginners. This article aims to help readers understand segment trees and their applications through a step-by-step implementation.

In this article, we will first implement a basic segment tree. Then we will evolve the implementation to use an array as the underlying storage, which performs better in practice.

How to implement a basic segment tree

A basic segment tree supports range queries and modifying a single value.

Now let's implement a basic segment tree that supports arbitrary range-sum queries, built from linked node objects (an explicit binary tree structure).
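Before building the tree, it can help to have a naive reference that offers the same two operations. The sketch below is not from the original article, and the class name NaiveRangeSum is hypothetical. It keeps the array as-is, so modify is O(1) but every range query is O(N); it is also handy as an oracle when testing a segment tree.

```java
// Hypothetical naive baseline (not part of the original article):
// modify is O(1), but every range query walks the whole range, O(N).
class NaiveRangeSum {
    private final int[] values;

    NaiveRangeSum(int[] nums) {
        if (nums == null || nums.length == 0) {
            throw new IllegalArgumentException("input array should not be empty");
        }
        this.values = nums.clone(); // defensive copy of the caller's array
    }

    int queryRange(int start, int end) {
        if (start < 0 || end >= values.length || start > end) {
            throw new IllegalArgumentException("invalid range");
        }
        int sum = 0;
        for (int i = start; i <= end; i++) { // O(end - start + 1)
            sum += values[i];
        }
        return sum;
    }

    void modify(int index, int newVal) {
        if (index < 0 || index >= values.length) {
            throw new IllegalArgumentException("input index is out of bound");
        }
        values[index] = newVal; // O(1)
    }

    public static void main(String[] args) {
        NaiveRangeSum naive = new NaiveRangeSum(new int[] {1, 3, 5, 7, 9, 11});
        System.out.println(naive.queryRange(1, 3)); // 3 + 5 + 7
        naive.modify(2, 10);
        System.out.println(naive.queryRange(1, 3)); // 3 + 10 + 7
    }
}
```

The segment tree trades the O(1) update for O(log N) so that queries also become O(log N).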

The tree node is defined as:

public class SegmentTreeNode {
    int start;
    int end;
    int sum;
    SegmentTreeNode left;
    SegmentTreeNode right;

    SegmentTreeNode(int start, int end, int sum) {
        this.start = start;
        this.end = end;
        this.sum = sum;
        this.left = null;
        this.right = null;
    }
}

Build the tree

In a segment tree, each leaf node represents a single element (i.e. node.start == node.end), while each internal node represents a range (i.e. node.start < node.end). Each internal node splits its range into two non-overlapping halves, which are represented by its left and right children.
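To see the shape this splitting produces before looking at the full build, the sketch below walks the same midpoint split and records each node's range in pre-order, without storing any sums. The class RangeSplitDemo is a hypothetical helper for illustration only. For a 5-element array it produces nine ranges, five of which are leaves.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: lists the [start, end] range of every node in pre-order,
// using the same midpoint split as buildTree. It only shows the tree's shape.
class RangeSplitDemo {
    static void collect(int start, int end, List<String> out) {
        boolean leaf = start == end;
        out.add("[" + start + "," + end + "]" + (leaf ? " leaf" : ""));
        if (leaf) {
            return; // a single element: no further split
        }
        int mid = start + (end - start) / 2; // avoids overflow of (start + end)
        collect(start, mid, out);            // left half: [start, mid]
        collect(mid + 1, end, out);          // right half: [mid + 1, end]
    }

    public static void main(String[] args) {
        List<String> ranges = new ArrayList<>();
        collect(0, 4, ranges); // a 5-element array
        ranges.forEach(System.out::println);
    }
}
```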

public class SegmentTree {
    private final SegmentTreeNode root;

    public SegmentTree(int[] nums) {
        if (nums == null || nums.length == 0) {
            throw new IllegalArgumentException("input array should not be empty");
        }
        this.root = buildTree(0, nums.length - 1, nums);
    }

    private SegmentTreeNode buildTree(int start, int end, int[] nums) {
        if (start == end) {
            return new SegmentTreeNode(start, end, nums[start]);
        }

        SegmentTreeNode node = new SegmentTreeNode(start, end, 0);
        int mid = start + (end - start) / 2;

        node.left = buildTree(start, mid, nums);
        node.right = buildTree(mid + 1, end, nums);

        node.sum = node.left.sum + node.right.sum;

        return node;
    }
}

As we can see, building the tree visits every element of the input array once to create its leaf node. Because every internal node has exactly two children, a tree with N leaves has exactly N - 1 internal nodes, i.e. 2N - 1 nodes in total, which is O(N). Each node is created with O(1) work, so building the tree takes O(N) time.
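The 2N - 1 node count can be checked empirically. This sketch (the class NodeCountDemo is hypothetical) follows the same recursion as buildTree and counts leaves and internal nodes separately for many array lengths.

```java
// Hypothetical helper: counts the nodes the midpoint split would create,
// separating leaves from internal nodes, to check the 2N - 1 total.
class NodeCountDemo {
    // Returns {leaves, internalNodes} for the subtree covering [start, end].
    static int[] count(int start, int end) {
        if (start == end) {
            return new int[] {1, 0};
        }
        int mid = start + (end - start) / 2;
        int[] left = count(start, mid);
        int[] right = count(mid + 1, end);
        // this node is internal: add 1 to the internal count
        return new int[] {left[0] + right[0], left[1] + right[1] + 1};
    }

    public static void main(String[] args) {
        for (int n = 1; n <= 1000; n++) {
            int[] c = count(0, n - 1);
            if (c[0] != n || c[1] != n - 1) {
                throw new AssertionError("unexpected counts for n = " + n);
            }
        }
        System.out.println("leaves = N and internal nodes = N - 1 for N = 1..1000");
    }
}
```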

Range query

When we query the sum of a given range, we start from the root node. If the current node's range (i.e. [node.start, node.end]) lies entirely within the query range, we return the node's sum directly. Otherwise, we recursively query the children whose ranges overlap the query range and add up their results.
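A useful way to picture this is that a query breaks its range into a few nodes that are fully covered. The sketch below (QueryDecompositionDemo is a hypothetical name) uses the same branching as query() but records the ranges of those fully covered nodes instead of adding sums. For an 8-element array, the query [2, 6] is answered by the three nodes [2,3], [4,5], and [6,6].

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: lists the nodes whose stored sums a range query would
// return directly, using the same branching as query() but no stored sums.
class QueryDecompositionDemo {
    static List<String> decompose(int n, int start, int end) {
        if (start < 0 || end >= n || start > end) {
            throw new IllegalArgumentException("invalid range");
        }
        List<String> out = new ArrayList<>();
        cover(0, n - 1, start, end, out);
        return out;
    }

    private static void cover(int nodeStart, int nodeEnd, int start, int end, List<String> out) {
        if (start <= nodeStart && nodeEnd <= end) {
            out.add("[" + nodeStart + "," + nodeEnd + "]"); // fully covered: use this node's sum
            return;
        }
        int mid = nodeStart + (nodeEnd - nodeStart) / 2;
        if (start <= mid) {
            cover(nodeStart, mid, start, end, out);   // query overlaps the left half
        }
        if (mid < end) {
            cover(mid + 1, nodeEnd, start, end, out); // query overlaps the right half
        }
    }

    public static void main(String[] args) {
        System.out.println(decompose(8, 2, 6));
    }
}
```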

public class SegmentTree {
    // ... code for building the tree

    public int queryRange(int start, int end) {
        if (start > end) {
            throw new IllegalArgumentException("start should not be greater than end");
        }
        return query(this.root, start, end);
    }

    private int query(SegmentTreeNode node, int start, int end) {
        if (node == null) {
            return 0;
        }
        if (start <= node.start && node.end <= end) {
            return node.sum;
        }

        int mid = node.start + (node.end - node.start) / 2;

        int sum = 0;
        if (start <= mid) {
            sum += query(node.left, start, end);
        }
        if (mid < end) {
            sum += query(node.right, start, end);
        }

        return sum;
    }
}

As we can see, a range query starts from the root node and compares the node's range with the query range. It returns immediately once a node's range lies fully within the query range; otherwise it descends into the children that overlap the query range. A query can branch into both children, but at each depth at most two nodes partially overlap the query range (the ones containing its left and right boundaries), so at most four nodes are visited per level and the cost is O(height). The tree is balanced because each internal node splits its range into two halves, represented by the left and right subtrees, so the height is O(log N) and a range query takes O(log N) time.
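The "at most four nodes per level" argument can be checked by brute force. The sketch below (QueryCostDemo is a hypothetical helper) counts the nodes visited by the same branching as query() for every valid range and compares the worst case with 4 * height + 1, where the height of a tree over n elements is ceil(log2 n) for this midpoint split.

```java
// Hypothetical helper: counts nodes visited by query()'s branching and
// checks the worst case against 4 nodes per level below the root.
class QueryCostDemo {
    static int visited(int nodeStart, int nodeEnd, int start, int end) {
        int count = 1; // this node
        if (start <= nodeStart && nodeEnd <= end) {
            return count; // fully covered: no further descent
        }
        int mid = nodeStart + (nodeEnd - nodeStart) / 2;
        if (start <= mid) {
            count += visited(nodeStart, mid, start, end);
        }
        if (mid < end) {
            count += visited(mid + 1, nodeEnd, start, end);
        }
        return count;
    }

    // Height of the tree built over n elements: ceil(log2 n).
    static int height(int n) {
        return n == 1 ? 0 : 32 - Integer.numberOfLeadingZeros(n - 1);
    }

    // Worst case over all valid ranges [start, end] of an n-element array.
    static int worstCase(int n) {
        int worst = 0;
        for (int start = 0; start < n; start++) {
            for (int end = start; end < n; end++) {
                worst = Math.max(worst, visited(0, n - 1, start, end));
            }
        }
        return worst;
    }

    public static void main(String[] args) {
        for (int n : new int[] {10, 100, 500}) {
            System.out.println("n = " + n + ", height = " + height(n)
                + ", max nodes visited = " + worstCase(n)
                + ", bound = " + (4 * height(n) + 1));
        }
    }
}
```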

Modify a single value

When we modify the value at a given index, we start from the root node and descend to the leaf node that represents that index. After updating the leaf, we recompute the sum of every internal node along the path from the root to that leaf.

public class SegmentTree {
    // ... code for building the tree & querying range value

    public void modify(int index, int newVal) {
        if (index < root.start || index > root.end) {
            throw new IllegalArgumentException("input index is out of bound");
        }
        modify(this.root, index, newVal);
    }

    private void modify(SegmentTreeNode node, int index, int val) {
        if (node == null) {
            return;
        }
        if (index == node.start && index == node.end) {
            node.sum = val;
            return;
        }

        int mid = node.start + (node.end - node.start) / 2;
        if (index <= mid) {
            modify(node.left, index, val);
        } else {
            modify(node.right, index, val);
        }
        node.sum = node.left.sum + node.right.sum;
    }
}

As we can see, modifying a value starts from the root node and chooses one direction (left or right) at each level until it reaches the leaf node, then recomputes the sums on the way back up. So modifying a single value takes O(height) = O(log N) time.
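The path that modify() follows can be listed directly. This sketch (ModifyPathDemo is a hypothetical name) walks the same left/right decisions iteratively and records each range; the recursive modify() recomputes the internal sums on this path in reverse order as the calls return. For an 8-element array and index 5, the path is [0,7], [4,7], [4,5], [5,5], i.e. height + 1 nodes.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: records the root-to-leaf path that modify() follows
// for a given index. Only the ranges are tracked; no sums are stored.
class ModifyPathDemo {
    static List<String> path(int n, int index) {
        if (index < 0 || index >= n) {
            throw new IllegalArgumentException("input index is out of bound");
        }
        List<String> out = new ArrayList<>();
        int start = 0;
        int end = n - 1;
        while (true) {
            out.add("[" + start + "," + end + "]");
            if (start == end) {
                return out; // reached the leaf for index
            }
            int mid = start + (end - start) / 2;
            if (index <= mid) {
                end = mid;       // go left
            } else {
                start = mid + 1; // go right
            }
        }
    }

    public static void main(String[] args) {
        System.out.println(path(8, 5));
    }
}
```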

Evolve the implementation using an array as the underlying storage

In the section above, we implemented a basic segment tree using linked node objects. Since every internal node has exactly two children (left and right), we can instead store the tree in an array, as with a binary heap: for the node stored at index idx (the array is 0-indexed), its left and right children are stored at indices 2 * idx + 1 and 2 * idx + 2, respectively. A node's range is no longer stored in the node itself; it is passed down as parameters during the recursion.
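Because the tree built by the midpoint split is not always complete, this indexing scheme can leave gaps in the array, so 2N - 1 slots are not always enough. The sketch below (HeapLayoutDemo is a hypothetical helper) follows the array-based recursion and reports the largest index it touches. For N = 6 the tree has only 11 nodes, but the last leaf lands at index 12, so 13 slots are needed; the 4 * N allocation used below covers every N.

```java
// Hypothetical helper: follows the array-based buildTree recursion
// (children at 2 * node + 1 and 2 * node + 2) and reports the slots it uses.
class HeapLayoutDemo {
    // Largest array index used by the subtree rooted at `node` covering [start, end].
    static int maxIndex(int node, int start, int end) {
        if (start == end) {
            return node; // leaves always have larger indices than their ancestors
        }
        int mid = start + (end - start) / 2;
        return Math.max(maxIndex(2 * node + 1, start, mid),
                        maxIndex(2 * node + 2, mid + 1, end));
    }

    static int slotsNeeded(int n) {
        return maxIndex(0, 0, n - 1) + 1;
    }

    public static void main(String[] args) {
        for (int n = 1; n <= 10; n++) {
            System.out.println("n = " + n + ": nodes = " + (2 * n - 1)
                + ", slots needed = " + slotsNeeded(n) + ", allocated = " + (4 * n));
        }
    }
}
```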

public class SegmentTree {
    private final int[] tree;
    private final int upperIdx;

    public SegmentTree(int[] nums) {
        if (nums == null || nums.length == 0) {
            throw new IllegalArgumentException("input array should not be empty");
        }

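        // Allocate 4 * N slots. With this indexing scheme the tree is not always
        // complete, so some node indices can exceed 2 * N - 2; 4 * N is a safe upper bound.
        this.tree = new int[nums.length * 4];

        // the last valid index of the input array, i.e. the end of the root's range.
        this.upperIdx = nums.length - 1;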

        buildTree(0, 0, upperIdx, nums);
    }

    public int queryRange(int start, int end) {
        if (start > end) {
            throw new IllegalArgumentException("start should not be greater than end");
        }
        return query(0, 0, upperIdx, start, end);
    }

    public void modify(int index, int newVal) {
        if (index < 0 || index > upperIdx) {
            throw new IllegalArgumentException("input index is out of bound");
        }
        modify(0, 0, upperIdx, index, newVal);
    }

    private void buildTree(int node, int start, int end, int[] nums) {
        if (start == end) {
            tree[node] = nums[start];
            return;
        }

        int leftChild = 2 * node + 1;
        int rightChild = 2 * node + 2;
        int mid = start + (end - start) / 2;

        buildTree(leftChild, start, mid, nums);
        buildTree(rightChild, mid + 1, end, nums);

        tree[node] = tree[leftChild] + tree[rightChild];
    }

    private int query(int node, int nodeStart, int nodeEnd, int start, int end) {
        if (end < nodeStart || nodeEnd < start) {
            return 0;
        }
        if (start <= nodeStart && nodeEnd <= end) {
            return tree[node];
        }

        int nodeMid = nodeStart + (nodeEnd - nodeStart) / 2;

        int sum = 0;
        if (start <= nodeMid) {
            sum += query(2 * node + 1, nodeStart, nodeMid, start, end);
        }
        if (nodeMid < end) {
            sum += query(2 * node + 2, nodeMid + 1, nodeEnd, start, end);
        }

        return sum;
    }

    private void modify(int node, int nodeStart, int nodeEnd, int idx, int val) {
        if (idx == nodeStart && idx == nodeEnd) {
            tree[node] = val;
            return;
        }

        int leftChild = 2 * node + 1;
        int rightChild = 2 * node + 2;
        int nodeMid = nodeStart + (nodeEnd - nodeStart) / 2;

        if (idx <= nodeMid) {
            modify(leftChild, nodeStart, nodeMid, idx, val);
        } else {
            modify(rightChild, nodeMid + 1, nodeEnd, idx, val);
        }
        tree[node] = tree[leftChild] + tree[rightChild];
    }
}
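As mentioned at the start, the same structure can answer minimum or maximum queries instead of sums. The sketch below is a hypothetical variant (the class MinSegmentTree and its queryMin method are illustrative names) of the array-based tree above: the combine step becomes Math.min instead of +, and a node outside the query range contributes Integer.MAX_VALUE, the identity for min, instead of 0.

```java
// Hypothetical variant of the array-based tree for range minimum queries:
// combine with Math.min, and use Integer.MAX_VALUE as the "no overlap" value.
class MinSegmentTree {
    private final int[] tree;
    private final int upperIdx;

    MinSegmentTree(int[] nums) {
        if (nums == null || nums.length == 0) {
            throw new IllegalArgumentException("input array should not be empty");
        }
        this.tree = new int[nums.length * 4]; // same 4 * N bound as above
        this.upperIdx = nums.length - 1;
        build(0, 0, upperIdx, nums);
    }

    int queryMin(int start, int end) {
        if (start < 0 || end > upperIdx || start > end) {
            throw new IllegalArgumentException("invalid range");
        }
        return query(0, 0, upperIdx, start, end);
    }

    void modify(int index, int newVal) {
        if (index < 0 || index > upperIdx) {
            throw new IllegalArgumentException("input index is out of bound");
        }
        modify(0, 0, upperIdx, index, newVal);
    }

    private void build(int node, int start, int end, int[] nums) {
        if (start == end) {
            tree[node] = nums[start];
            return;
        }
        int mid = start + (end - start) / 2;
        build(2 * node + 1, start, mid, nums);
        build(2 * node + 2, mid + 1, end, nums);
        tree[node] = Math.min(tree[2 * node + 1], tree[2 * node + 2]);
    }

    private int query(int node, int nodeStart, int nodeEnd, int start, int end) {
        if (end < nodeStart || nodeEnd < start) {
            return Integer.MAX_VALUE; // no overlap: identity element for min
        }
        if (start <= nodeStart && nodeEnd <= end) {
            return tree[node]; // fully covered
        }
        int mid = nodeStart + (nodeEnd - nodeStart) / 2;
        return Math.min(query(2 * node + 1, nodeStart, mid, start, end),
                        query(2 * node + 2, mid + 1, nodeEnd, start, end));
    }

    private void modify(int node, int nodeStart, int nodeEnd, int idx, int val) {
        if (nodeStart == nodeEnd) {
            tree[node] = val; // the leaf for idx
            return;
        }
        int mid = nodeStart + (nodeEnd - nodeStart) / 2;
        if (idx <= mid) {
            modify(2 * node + 1, nodeStart, mid, idx, val);
        } else {
            modify(2 * node + 2, mid + 1, nodeEnd, idx, val);
        }
        tree[node] = Math.min(tree[2 * node + 1], tree[2 * node + 2]);
    }

    public static void main(String[] args) {
        MinSegmentTree t = new MinSegmentTree(new int[] {5, 2, 8, 6, 3, 7});
        System.out.println(t.queryMin(2, 4)); // min of 8, 6, 3
        t.modify(4, 1);
        System.out.println(t.queryMin(2, 4)); // min of 8, 6, 1
    }
}
```

A maximum tree follows the same pattern with Math.max and Integer.MIN_VALUE.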

Conclusion

In this article, we discussed how to build a basic segment tree, query a range, and update a single value. We implemented the segment tree both with linked node objects and with a fixed-size array.

From the implementation, the time complexities are:

  • Build the tree: O(N)

  • Make a range query: O(log N)

  • Update a single value: O(log N)

In the next article, we will discuss how to further evolve the segment tree to support updating a range of values in O(log N) time.