Utilize a Segment Tree to Solve a Problem

If you don't know what a segment tree is or how to implement one that supports range updates, please refer to these two articles: Efficiently Implement a Basic Segment Tree & Segment Tree with Lazy Propagation.

Put simply, a lazy segment tree (or segment tree with lazy propagation) can perform a range query or a range update in O(log N) time. In this article, we will discuss how to use a lazy segment tree to solve a practical problem.

Problem

You are given two 0-indexed int arrays bits and nums (0 <= bits[i] <= 1, -10^9 <= nums[i] <= 10^9, len(bits) == len(nums)). Now design a data structure that supports the following three operations:

  • flip(int l, int r): Flip the values from 0 to 1 and from 1 to 0 in bits from index l to index r. Both l and r are 0-indexed and inclusive.

  • update(int p): For every index 0 <= i < len(nums), set nums[i] = nums[i] + bits[i] * p.

  • query(): Return the sum of the elements in nums.

You should implement the data structure so that it supports the above operations efficiently (the original problem can be found on LeetCode).
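To pin down the semantics before optimizing anything, here is a minimal brute-force sketch that applies each operation literally. The class name NaiveSolution is hypothetical, and nums is widened to long (our assumption) because repeated updates can push the values and their sum beyond the int range. Every flip and update touches up to N elements and query rescans the array, which is exactly the cost the rest of the article tries to avoid.

```java
// Hypothetical brute-force reference: applies each operation literally.
class NaiveSolution {
    private final int[] bits;
    private final long[] nums;

    NaiveSolution(int[] bits, int[] nums) {
        this.bits = bits.clone();
        this.nums = new long[nums.length];
        for (int i = 0; i < nums.length; i++) {
            this.nums[i] = nums[i]; // widen to long: values and sums can exceed the int range
        }
    }

    // O(r - l + 1): toggle every bit in [l, r], both ends inclusive.
    void flip(int l, int r) {
        for (int i = l; i <= r; i++) {
            bits[i] ^= 1;
        }
    }

    // O(N): add p to every position whose bit is 1.
    void update(int p) {
        for (int i = 0; i < nums.length; i++) {
            nums[i] += (long) bits[i] * p;
        }
    }

    // O(N): recompute the sum on every call.
    long query() {
        long sum = 0;
        for (long num : nums) {
            sum += num;
        }
        return sum;
    }
}
```

A version like this is also handy as a reference to test a faster implementation against.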

Solution

Intuition

The data structure supports two write operations (flip & update) and one read operation (query):

  • Regarding the flip operation, it flips a whole range of bits in the bits array, so a lazy segment tree, which supports range updates in O(log N), is a natural fit here.

  • Regarding the update operation, it adds an integer p to nums[i] for every 0 <= i < len(nums) with bits[i] == 1. Since exactly cardinality(bits) elements each gain p, this is equivalent to adding p * cardinality(bits) to sum(nums). Here, cardinality(bits) returns the number of ones in the bits array. The good news is that the segment tree provides an efficient range query that can be used to implement the cardinality operation. In fact, cardinality is a special case of range query whose range is always the full range [0, len(bits) - 1].

  • Regarding the query operation, it just returns the current sum of nums. This operation should be trivial since we can cache the sum of nums after each update operation and return it directly when the query operation is called.
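The reduction described above can be sketched with a plain counter of ones before any segment tree is involved. In this hypothetical CountingSolution, update and query already run in O(1) because they only touch the cached sum and the count of ones; flip, however, still walks the range bit by bit. That remaining linear step is what the lazy segment tree replaces.

```java
// Hypothetical intermediate design: cache sum(nums) and cardinality(bits),
// but still flip bit by bit (no segment tree yet).
class CountingSolution {
    private final int[] bits;
    private int ones;  // cardinality(bits), kept up to date by flip
    private long sum;  // cached sum(nums); long to avoid int overflow

    CountingSolution(int[] bits, int[] nums) {
        this.bits = bits.clone();
        for (int b : bits) {
            ones += b;
        }
        for (int num : nums) {
            sum += num;
        }
    }

    // O(r - l + 1): each toggled bit changes the number of ones by +1 or -1.
    void flip(int l, int r) {
        for (int i = l; i <= r; i++) {
            ones += (bits[i] == 0) ? 1 : -1;
            bits[i] ^= 1;
        }
    }

    // O(1): each of the `ones` selected elements gains p, so the sum gains p * ones.
    void update(int p) {
        sum += (long) p * ones;
    }

    // O(1): return the cached sum.
    long query() {
        return sum;
    }
}
```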

Implementation

public class Solution {
    private final SegmentTree bitsTree;
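    private long sum; // long, because the sum can exceed the int range

    public Solution(int[] bits, int[] nums) {
        this.bitsTree = new SegmentTree(bits);
        this.sum = getSum(nums);
    }

    public void flip(int l, int r) {
        bitsTree.flip(l, r);
    }

    public void update(int p) {
        // Every index with bits[i] == 1 gains p, so the sum grows by p * (number of ones).
        // Cast to long before multiplying to avoid int overflow.
        sum += (long) p * bitsTree.cardinality();
    }

    public long query() {
        return sum;
    }

    private long getSum(int[] nums) {
        long sum = 0;
        for (int num : nums) {
            sum += num;
        }
        return sum;
    }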
}

class SegmentTree {
    private final int upperIdx;
    private final int[] tree;
    private final int[] lazy;

    SegmentTree(int[] bits) {
        this.upperIdx = bits.length - 1;
        this.tree = new int[bits.length * 4];
        this.lazy = new int[bits.length * 4];
        this.buildTree(0, 0, upperIdx, bits);
    }

    void flip(int left, int right) {
        flip(0, 0, upperIdx, left, right);
    }

    int cardinality() {
        // tree[0] stores the number of ones for the entire array.
        return tree[0];
    }

    private void buildTree(int node, int start, int end, int[] bits) {
        if (start == end) {
            tree[node] = bits[start];
            return;
        }

        int leftChild = 2 * node + 1;
        int rightChild = 2 * node + 2;
        int mid = start + (end - start) / 2;

        buildTree(leftChild, start, mid, bits);
        buildTree(rightChild, mid + 1, end, bits);

        // tree[node] stores the number of ones in range [start, end].
        tree[node] = tree[leftChild] + tree[rightChild];
    }

    private void flip(int node, int start, int end, int left, int right) {
        // non-zero means there are pending flips.
        if (lazy[node] != 0) {
            applyFlips(node, start, end);
            lazy[node] = 0; // mark the node no longer lazy
        }

        // If the node lies outside the input range,
        // then we simply return.
        if (right < start || end < left) {
            return;
        }

        // If the node lies fully inside the input range,
        // then we simply update the node and mark its children lazy.
        if (left <= start && end <= right) {
            applyFlips(node, start, end);
            return;
        }

        int leftChild = 2 * node + 1;
        int rightChild = 2 * node + 2;
        int mid = start + (end - start) / 2;

        flip(leftChild, start, mid, left, right);
        flip(rightChild, mid + 1, end, left, right);

        tree[node] = tree[leftChild] + tree[rightChild];
    }

    private void applyFlips(int node, int start, int end) {
        // apply the range flips to the node.
        tree[node] = (end - start + 1) - tree[node];

        // if its children exist then mark them as lazy.
        if (start < end) {
            // we use XOR operation to implement flip.
            lazy[2 * node + 1] ^= 1;
            lazy[2 * node + 2] ^= 1;
        }
    }
}
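Since cardinality is just a range query over [0, len(bits) - 1], the same tree can answer arbitrary range counts. The sketch below is a compact, self-contained restatement of the class above with an added countOnes(left, right) method; the class name RangeCountSegmentTree and the helper push are hypothetical names, not part of the original solution. Its cardinality() delegates to countOnes(0, upperIdx), which returns the same value as reading tree[0], because the root never carries a pending flip.

```java
// Hypothetical variant of SegmentTree with a general range-count query.
class RangeCountSegmentTree {
    private final int upperIdx;
    private final int[] tree; // tree[node] = number of ones in the node's range
    private final int[] lazy; // 1 = a pending flip for this node

    RangeCountSegmentTree(int[] bits) {
        upperIdx = bits.length - 1;
        tree = new int[bits.length * 4];
        lazy = new int[bits.length * 4];
        build(0, 0, upperIdx, bits);
    }

    void flip(int left, int right) { flip(0, 0, upperIdx, left, right); }

    // Number of ones in bits[left..right], both inclusive.
    int countOnes(int left, int right) { return count(0, 0, upperIdx, left, right); }

    // The full-range special case.
    int cardinality() { return countOnes(0, upperIdx); }

    private void build(int node, int start, int end, int[] bits) {
        if (start == end) { tree[node] = bits[start]; return; }
        int mid = start + (end - start) / 2;
        build(2 * node + 1, start, mid, bits);
        build(2 * node + 2, mid + 1, end, bits);
        tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
    }

    // Resolve a pending flip on this node before reading or writing it.
    private void push(int node, int start, int end) {
        if (lazy[node] != 0) {
            applyFlip(node, start, end);
            lazy[node] = 0;
        }
    }

    // Flipping a range of length len with k ones leaves len - k ones.
    private void applyFlip(int node, int start, int end) {
        tree[node] = (end - start + 1) - tree[node];
        if (start < end) {
            lazy[2 * node + 1] ^= 1; // two pending flips cancel out
            lazy[2 * node + 2] ^= 1;
        }
    }

    private void flip(int node, int start, int end, int left, int right) {
        push(node, start, end);
        if (right < start || end < left) return;
        if (left <= start && end <= right) { applyFlip(node, start, end); return; }
        int mid = start + (end - start) / 2;
        flip(2 * node + 1, start, mid, left, right);
        flip(2 * node + 2, mid + 1, end, left, right);
        tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
    }

    private int count(int node, int start, int end, int left, int right) {
        push(node, start, end);
        if (right < start || end < left) return 0;
        if (left <= start && end <= right) return tree[node];
        int mid = start + (end - start) / 2;
        return count(2 * node + 1, start, mid, left, right)
             + count(2 * node + 2, mid + 1, end, left, right);
    }
}
```

Like flip, count resolves a node's pending flip before using it, so a query may push lazy marks one level down; this keeps every visited node's count accurate without a separate rebuild, and each query still visits O(log N) nodes per level boundary.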

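From the implementation above, the time complexity of each operation follows directly, where N = len(bits):

  • Constructor: O(N) to build the tree and compute the initial sum

  • flip(int l, int r): O(log N)

  • update(int p): O(1), because cardinality() just reads the root node tree[0]

  • query(): O(1)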

Conclusion

The problem above can be solved in other ways as well, for example with a BitSet. But as you can see, with the help of a lazy segment tree, we can solve it efficiently and elegantly.
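For comparison, the BitSet alternative might look like the following hedged sketch (BitSetSolution is a hypothetical name). It relies only on java.util.BitSet's set(int), flip(int fromIndex, int toIndex) and cardinality(); note that flip takes an exclusive upper bound. Both flip and cardinality work one 64-bit word at a time, so these operations are linear in N with a small constant rather than logarithmic.

```java
import java.util.BitSet;

// Hypothetical BitSet-based alternative to the segment tree solution.
class BitSetSolution {
    private final BitSet bits = new BitSet();
    private long sum; // cached sum(nums); long to avoid int overflow

    BitSetSolution(int[] bitsArray, int[] nums) {
        for (int i = 0; i < bitsArray.length; i++) {
            if (bitsArray[i] == 1) {
                bits.set(i);
            }
        }
        for (int num : nums) {
            sum += num;
        }
    }

    // BitSet.flip uses an exclusive upper bound, so pass r + 1 to include r.
    void flip(int l, int r) {
        bits.flip(l, r + 1);
    }

    // cardinality() counts set bits word by word across the whole set.
    void update(int p) {
        sum += (long) p * bits.cardinality();
    }

    long query() {
        return sum;
    }
}
```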

If you are interested in the Fenwick Tree, an alternative to the segment tree, and in how the two compare, please refer to the article Fenwick Tree vs Segment Tree.