If you don't know what a segment tree is, or how to implement one that supports range updates, please refer to these two articles: Efficiently Implement a Basic Segment Tree & Segment Tree with Lazy Propagation.
Put simply, a lazy segment tree (or segment tree with lazy propagation) can perform a range query or a range update in O(log N) time. In this article, we will discuss how to use a lazy segment tree to solve a practical problem.
Problem
You are given two 0-indexed int arrays bits and nums (0 <= bits[i] <= 1, -10^9 <= nums[i] <= 10^9, len(bits) == len(nums)). Design a data structure that supports the following three operations:
- flip(int l, int r): Flip the values in bits from 0 to 1 and from 1 to 0, from index l to index r. Both l and r are 0-indexed and inclusive.
- update(int p): For every index 0 <= i < len(nums), set nums[i] = nums[i] + bits[i] * p.
- query(): Return the sum of the elements in nums.
Implement the data structure so that all three operations run efficiently (the original problem can be found on LeetCode).
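To pin down the semantics before optimizing, here is a deliberately naive reference implementation. The class name NaiveSolution is hypothetical and not part of the original problem. It stores nums as long values because the sum of values up to 10^9 in magnitude quickly exceeds the int range. Every operation scans the arrays, so flip, update, and query each take O(N) time; it is only meant as a baseline for checking faster versions.

```java
// Hypothetical brute-force reference: every operation is O(N).
class NaiveSolution {
    private final int[] bits;
    private final long[] nums; // long, so that large values do not overflow

    NaiveSolution(int[] bits, int[] nums) {
        this.bits = bits.clone();
        this.nums = new long[nums.length];
        for (int i = 0; i < nums.length; i++) {
            this.nums[i] = nums[i];
        }
    }

    // Flip bits[l..r] (inclusive) one element at a time.
    void flip(int l, int r) {
        for (int i = l; i <= r; i++) {
            bits[i] ^= 1;
        }
    }

    // Add p to every nums[i] whose bit is 1.
    void update(int p) {
        for (int i = 0; i < nums.length; i++) {
            nums[i] += (long) bits[i] * p;
        }
    }

    // Recompute the sum of nums from scratch.
    long query() {
        long sum = 0;
        for (long num : nums) {
            sum += num;
        }
        return sum;
    }
}
```

For example, with bits = [1, 0, 1] and nums = [1, 2, 3], update(5) turns nums into [6, 2, 8], so query() returns 16.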
Solution
Intuition
The data structure supports two write operations (flip and update) and one read operation (query):
- Regarding the flip operation, it flips a range of bits in the bits array, which suggests that a lazy segment tree may be applicable here.
- Regarding the update operation, it adds the integer p to nums[i] for every 0 <= i < len(nums) with bits[i] == 1, which is equivalent to adding p * cardinality(bits) to sum(nums). Here, cardinality(bits) returns the number of ones in the bits array. The good news is that the segment tree provides an efficient range query that can be used to implement the cardinality operation. In fact, cardinality is a special case of a range query whose range is always the full range [0, len(bits) - 1], so its answer is simply the value stored at the root of the tree.
- Regarding the query operation, it just returns the current sum of nums. This operation is trivial, since we can cache the sum of nums, adjust it after each update operation, and return it directly when query is called.
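The update argument above can be written out as a short derivation, where n denotes len(nums). The last step uses the fact that each bits_i is either 0 or 1, so summing them counts the ones:

```latex
\sum_{i=0}^{n-1} \left(\mathit{nums}_i + \mathit{bits}_i \cdot p\right)
  = \sum_{i=0}^{n-1} \mathit{nums}_i + p \sum_{i=0}^{n-1} \mathit{bits}_i
  = \mathrm{sum}(\mathit{nums}) + p \cdot \mathrm{cardinality}(\mathit{bits})
```

For instance, with bits = [1, 0, 1], nums = [1, 2, 3] and p = 5, the right-hand side gives 6 + 5 × 2 = 16, which matches summing the updated array directly: (1 + 5) + 2 + (3 + 5) = 16.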
Implementation
```java
public class Solution {
    private final SegmentTree bitsTree;
    // long, because sums of values up to 10^9 in magnitude overflow int.
    private long sum;

    public Solution(int[] bits, int[] nums) {
        this.bitsTree = new SegmentTree(bits);
        this.sum = getSum(nums);
    }

    public void flip(int l, int r) {
        bitsTree.flip(l, r);
    }

    public void update(int p) {
        // Every index with bits[i] == 1 gains p, so the total gains p * (number of ones).
        // Cast to long before multiplying to avoid int overflow.
        sum += (long) p * bitsTree.cardinality();
    }

    public long query() {
        return sum;
    }

    private long getSum(int[] nums) {
        long sum = 0;
        for (int num : nums) {
            sum += num;
        }
        return sum;
    }
}

class SegmentTree {
    private final int upperIdx;
    private final int[] tree;
    private final int[] lazy;

    SegmentTree(int[] bits) {
        this.upperIdx = bits.length - 1;
        this.tree = new int[bits.length * 4];
        this.lazy = new int[bits.length * 4];
        this.buildTree(0, 0, upperIdx, bits);
    }

    void flip(int left, int right) {
        flip(0, 0, upperIdx, left, right);
    }

    int cardinality() {
        // tree[0] stores the number of ones for the entire array.
        return tree[0];
    }

    private void buildTree(int node, int start, int end, int[] bits) {
        if (start == end) {
            tree[node] = bits[start];
            return;
        }
        int leftChild = 2 * node + 1;
        int rightChild = 2 * node + 2;
        int mid = start + (end - start) / 2;
        buildTree(leftChild, start, mid, bits);
        buildTree(rightChild, mid + 1, end, bits);
        // tree[node] stores the number of ones in range [start, end].
        tree[node] = tree[leftChild] + tree[rightChild];
    }

    private void flip(int node, int start, int end, int left, int right) {
        // Non-zero means there are pending flips.
        if (lazy[node] != 0) {
            applyFlips(node, start, end);
            lazy[node] = 0; // mark the node as no longer lazy
        }
        // If the node lies outside the input range,
        // then we simply return.
        if (right < start || end < left) {
            return;
        }
        // If the node lies fully inside the input range,
        // then we simply update the node and mark its children lazy.
        if (left <= start && end <= right) {
            applyFlips(node, start, end);
            return;
        }
        // Partial overlap: recurse into both children, then recompute this node.
        int leftChild = 2 * node + 1;
        int rightChild = 2 * node + 2;
        int mid = start + (end - start) / 2;
        flip(leftChild, start, mid, left, right);
        flip(rightChild, mid + 1, end, left, right);
        tree[node] = tree[leftChild] + tree[rightChild];
    }

    private void applyFlips(int node, int start, int end) {
        // Flip the node: its ones become zeros and its zeros become ones.
        tree[node] = (end - start + 1) - tree[node];
        // If its children exist, mark them as lazy.
        if (start < end) {
            // We use XOR to implement flip, so two pending flips cancel out.
            lazy[2 * node + 1] ^= 1;
            lazy[2 * node + 2] ^= 1;
        }
    }
}
```
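One way to gain confidence in the lazy flip logic is to compare it against brute force on random inputs. The sketch below is self-contained, so it re-declares the SegmentTree logic above in condensed form under a hypothetical name, FlipTree, and adds a hypothetical count(l, r) range query to show that cardinality is just the full-range case count(0, n - 1). The push helper only factors out the "resolve pending flips" step that the original flip method performs inline; it is not a different technique.

```java
import java.util.Random;

// Hypothetical condensed copy of the article's SegmentTree, plus a count(l, r) query.
class FlipTree {
    private final int n;
    private final int[] tree; // tree[node] = number of ones in the node's range
    private final int[] lazy; // lazy[node] = 1 if a flip is still pending for the node

    FlipTree(int[] bits) {
        n = bits.length;
        tree = new int[4 * n];
        lazy = new int[4 * n];
        build(0, 0, n - 1, bits);
    }

    int cardinality() { return tree[0]; }

    void flip(int l, int r) { flip(0, 0, n - 1, l, r); }

    // Number of ones in bits[l..r]; count(0, n - 1) equals cardinality().
    int count(int l, int r) { return count(0, 0, n - 1, l, r); }

    private void build(int node, int s, int e, int[] bits) {
        if (s == e) { tree[node] = bits[s]; return; }
        int m = s + (e - s) / 2;
        build(2 * node + 1, s, m, bits);
        build(2 * node + 2, m + 1, e, bits);
        tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
    }

    // Flip the node's count and defer the flip to its children.
    private void apply(int node, int s, int e) {
        tree[node] = (e - s + 1) - tree[node];
        if (s < e) { lazy[2 * node + 1] ^= 1; lazy[2 * node + 2] ^= 1; }
    }

    // Resolve a pending flip before the node's count is read or changed.
    private void push(int node, int s, int e) {
        if (lazy[node] != 0) { apply(node, s, e); lazy[node] = 0; }
    }

    private void flip(int node, int s, int e, int l, int r) {
        push(node, s, e);
        if (r < s || e < l) return;
        if (l <= s && e <= r) { apply(node, s, e); return; }
        int m = s + (e - s) / 2;
        flip(2 * node + 1, s, m, l, r);
        flip(2 * node + 2, m + 1, e, l, r);
        tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
    }

    private int count(int node, int s, int e, int l, int r) {
        push(node, s, e);
        if (r < s || e < l) return 0;
        if (l <= s && e <= r) return tree[node];
        int m = s + (e - s) / 2;
        return count(2 * node + 1, s, m, l, r) + count(2 * node + 2, m + 1, e, l, r);
    }
}

class FlipTreeCheck {
    // Applies the same random flips to FlipTree and to a plain int[],
    // then returns how many range counts or cardinalities disagreed.
    static int mismatches(int n, int ops, long seed) {
        Random rnd = new Random(seed);
        int[] bits = new int[n];
        for (int i = 0; i < n; i++) bits[i] = rnd.nextInt(2);
        FlipTree tree = new FlipTree(bits);
        int bad = 0;
        for (int k = 0; k < ops; k++) {
            int a = rnd.nextInt(n), b = rnd.nextInt(n);
            tree.flip(Math.min(a, b), Math.max(a, b));
            for (int i = Math.min(a, b); i <= Math.max(a, b); i++) bits[i] ^= 1;

            int c = rnd.nextInt(n), d = rnd.nextInt(n);
            int l = Math.min(c, d), r = Math.max(c, d);
            int expected = 0, total = 0;
            for (int i = 0; i < n; i++) {
                total += bits[i];
                if (l <= i && i <= r) expected += bits[i];
            }
            if (tree.count(l, r) != expected || tree.cardinality() != total) bad++;
        }
        return bad;
    }

    public static void main(String[] args) {
        System.out.println(mismatches(50, 1000, 42L)); // prints 0
    }
}
```

Because every node on the visited path resolves its own pending flip before it is read, the value at the root is always exact, which is why cardinality() can return tree[0] without any traversal.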
From the implementation above, the time complexity of each operation follows directly:
- constructor: O(N) to build the tree
- flip(int l, int r): O(log N)
- update(int p): O(1), since cardinality() only reads the root
- query(): O(1)
Conclusion
Actually, this problem can be solved in several ways, e.g. with a BitSet. But as you can see, with the help of a lazy segment tree, we can solve it efficiently and elegantly.
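For comparison, here is a minimal sketch of the BitSet approach mentioned above, assuming java.util.BitSet; the class name BitSetSolution is hypothetical. BitSet.flip(fromIndex, toIndex) excludes toIndex, hence the r + 1. The trade-off is asymptotic: cardinality() scans every 64-bit word in use, so update costs roughly O(N / 64) instead of O(1), and flip costs time proportional to the flipped range divided by 64 rather than O(log N).

```java
import java.util.BitSet;

// Hypothetical BitSet-based variant of the Solution class.
class BitSetSolution {
    private final BitSet bits;
    private long sum; // cached sum of nums, as in the segment tree version

    BitSetSolution(int[] bits, int[] nums) {
        this.bits = new BitSet(bits.length);
        for (int i = 0; i < bits.length; i++) {
            if (bits[i] == 1) {
                this.bits.set(i);
            }
        }
        for (int num : nums) {
            sum += num;
        }
    }

    // BitSet.flip uses an exclusive upper bound, so r + 1 keeps r inclusive.
    void flip(int l, int r) {
        bits.flip(l, r + 1);
    }

    // cardinality() counts the set bits by scanning the underlying words.
    void update(int p) {
        sum += (long) p * bits.cardinality();
    }

    long query() {
        return sum;
    }
}
```

The code is shorter than the segment tree, which can make it a reasonable choice when the arrays are small or performance is not critical.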
If you are interested in an alternative to the segment tree, the Fenwick tree, and in how the two compare, please refer to the article Fenwick Tree vs Segment Tree.