Introduction
The Segment Tree data structure allows answering range queries over an array efficiently, while still being flexible enough to allow modifying the array. This includes finding the sum of consecutive array elements a[l…r], or finding the minimum element in such a range, in O(log n) time.
A Segment Tree is basically a binary tree used for storing intervals, or segments. Each node in the Segment Tree represents an interval. Consider an array A of size N and a corresponding Segment Tree T:
- The root of T will represent the whole array A[0:N−1].
- Each leaf in the Segment Tree T will represent a single element A[i] such that 0≤i<N.
- The internal nodes in the Segment Tree T represent the union of elementary intervals A[i:j] where 0≤i<j<N.
Construction
Segment Trees are stored using an array representation of the tree. For the node at index i, the left child is at index 2*i+1 and the right child at index 2*i+2. We start with the segment A[0 . . . N-1]. As long as the current segment is longer than one element, we divide it into two halves and call the same procedure on both halves; for each segment, we store its sum in the corresponding node. The time complexity of the construction is O(N), because each node is computed exactly once.
All levels of the constructed segment tree will be completely filled except possibly the last level. The tree is also a full binary tree, because every segment longer than one element is divided into exactly two halves. A full binary tree with N leaves has N-1 internal nodes, so the total number of nodes is 2*N - 1. Note that this count does not include dummy nodes, i.e. array slots that correspond to no segment.
If N is a power of 2, then there are no dummy nodes, so the size of the segment tree is 2*N - 1 (N leaf nodes and N-1 internal nodes). If N is not a power of 2, then the size of the array holding the tree will be 2*X - 1, where X is the smallest power of 2 greater than N. For example, when N = 10, X = 16 and the size of the array representing the segment tree is 2*16-1 = 31.
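The sizing rule above can be sketched as a small helper. The class and method names (`SegmentTreeSize`, `nextPowerOfTwo`, `treeArraySize`) are hypothetical and only serve as an illustration; the sketch assumes 1 ≤ N ≤ 2^30 so that the arithmetic stays within `int`.

```java
class SegmentTreeSize {
    // Smallest power of two that is >= n (assumes 1 <= n <= 2^30).
    static int nextPowerOfTwo(int n) {
        int x = 1;
        while (x < n) {
            x *= 2;
        }
        return x;
    }

    // Number of array slots needed for the tree, dummy nodes included.
    static int treeArraySize(int n) {
        return 2 * nextPowerOfTwo(n) - 1;
    }

    public static void main(String[] args) {
        System.out.println(treeArraySize(8));  // prints 15 (N is a power of 2)
        System.out.println(treeArraySize(10)); // prints 31 (X = 16)
    }
}
```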
    // Construct Segment Tree for A[start..end]; returns the sum stored at node idx
    int buildSegmentTree(int[] A, int start, int end, int[] tree, int idx) {
        // One element in array, store it in current node
        if (start == end) {
            tree[idx] = A[start];
            return A[start];
        }
        // Recur for left and right segment
        int mid = start + (end - start) / 2;
        tree[idx] = buildSegmentTree(A, start, mid, tree, idx * 2 + 1)
                  + buildSegmentTree(A, mid + 1, end, tree, idx * 2 + 2);
        return tree[idx];
    }
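Here is a representation of such a Segment Tree over the array a=[1,3,−2,8,−7], with each node written as its segment followed by the stored sum, and children indented under their parent:
    a[0..4]: 3
        a[0..2]: 2
            a[0..1]: 4
                a[0..0]: 1
                a[1..1]: 3
            a[2..2]: -2
        a[3..4]: 1
            a[3..3]: 8
            a[4..4]: -7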
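The tree for this array can also be reproduced in code. The following is a minimal sketch that repeats the recursive construction under a hypothetical class name (`BuildExample`) and prints the backing array. It assumes the 2*i+1 / 2*i+2 layout described earlier and an array of 2*8-1 = 15 slots for N = 5.

```java
import java.util.Arrays;

class BuildExample {
    // Leaves hold single elements; internal nodes hold the sum of their children.
    static int build(int[] a, int start, int end, int[] tree, int idx) {
        if (start == end) {
            tree[idx] = a[start];
            return tree[idx];
        }
        int mid = start + (end - start) / 2;
        tree[idx] = build(a, start, mid, tree, 2 * idx + 1)
                  + build(a, mid + 1, end, tree, 2 * idx + 2);
        return tree[idx];
    }

    public static void main(String[] args) {
        int[] a = {1, 3, -2, 8, -7};
        int[] tree = new int[15]; // 2*8 - 1 slots for N = 5
        build(a, 0, a.length - 1, tree, 0);
        // Index 0 is the root; unused (dummy) slots keep their default value 0.
        System.out.println(Arrays.toString(tree));
        // prints [3, 2, 1, 4, -2, 8, -7, 1, 3, 0, 0, 0, 0, 0, 0]
    }
}
```

Index 0 holds the sum of the whole array, indices 1 and 2 hold the sums of a[0..2] and a[3..4], and so on down to the leaves.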
Application
Once the Segment Tree is built, its structure cannot be changed: we can update the values stored in the nodes, but not the shape of the tree. The Segment Tree provides two operations:
- Update: Change an element of the array A and reflect the corresponding change in the Segment Tree.
- Query: Ask about an interval or segment and return the answer to the problem (say the minimum, maximum, or sum over that segment).
Now suppose we want to modify a specific element in the array (update), say by the assignment A[i]=x. We then have to update the Segment Tree so that it corresponds to the new, modified array. Each level of a Segment Tree forms a partition of the array, so an element A[i] contributes to only one segment on each level. An update function can therefore receive the current tree vertex, recurse into the one child whose segment contains A[i], and then recompute its own sum, which takes O(log N) time. The function below takes a slightly different route: instead of recomputing sums, it receives the difference diff = new_value - old_value and adds it to every node whose segment contains index i.
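The single-path update described above can be sketched as follows. This is an illustrative variant, not the difference-based routine that follows: the names `PointAssign` and `assign` are hypothetical, and the hard-coded array is assumed to be the sum tree for a = [1, 3, -2, 8, -7] in the 2*i+1 / 2*i+2 layout.

```java
class PointAssign {
    // Sets A[i] = x in a sum segment tree.
    // [start, end] is the segment covered by the node at idx.
    static void assign(int[] tree, int idx, int start, int end, int i, int x) {
        if (start == end) {
            tree[idx] = x; // reached the leaf for A[i]
            return;
        }
        int mid = start + (end - start) / 2;
        // Descend only into the child whose segment contains i.
        if (i <= mid) {
            assign(tree, 2 * idx + 1, start, mid, i, x);
        } else {
            assign(tree, 2 * idx + 2, mid + 1, end, i, x);
        }
        // Recompute this node's sum from its two children.
        tree[idx] = tree[2 * idx + 1] + tree[2 * idx + 2];
    }

    public static void main(String[] args) {
        // Assumed sum tree for a = [1, 3, -2, 8, -7]
        int[] tree = {3, 2, 1, 4, -2, 8, -7, 1, 3, 0, 0, 0, 0, 0, 0};
        assign(tree, 0, 0, 4, 2, 5);  // a[2] = 5
        System.out.println(tree[0]);  // prints 10
    }
}
```

Only the nodes on the path from the root to the leaf for A[i] are touched, one per level.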
    // Adds diff (new_value - old_value) to every node whose segment contains index i.
    // [start, end] is the segment covered by the node at idx.
    void updateValueUtil(int[] tree, int start, int end, int i, int diff, int idx) {
        // Index i lies outside this node's segment
        if (i < start || i > end)
            return;
        // This node's segment contains i, so its sum changes by diff
        tree[idx] = tree[idx] + diff;
        // Continue with the children unless this node is a leaf
        if (end != start) {
            int mid = start + (end - start) / 2;
            updateValueUtil(tree, start, mid, i, diff, 2 * idx + 1);
            updateValueUtil(tree, mid + 1, end, i, diff, 2 * idx + 2);
        }
    }
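Because the routine above works with a difference rather than the new value, a caller has to compute that difference and keep the array A in sync. A possible wrapper is sketched below; `UpdateExample` and `updateValue` are hypothetical names, and the hard-coded tree is assumed to be the sum tree for a = [1, 3, -2, 8, -7].

```java
class UpdateExample {
    // Difference-based update: adds diff to every node whose segment contains i.
    static void updateValueUtil(int[] tree, int start, int end, int i, int diff, int idx) {
        if (i < start || i > end) {
            return;
        }
        tree[idx] += diff;
        if (start != end) {
            int mid = start + (end - start) / 2;
            updateValueUtil(tree, start, mid, i, diff, 2 * idx + 1);
            updateValueUtil(tree, mid + 1, end, i, diff, 2 * idx + 2);
        }
    }

    // Hypothetical wrapper: performs A[i] = newValue and updates the tree.
    static void updateValue(int[] a, int[] tree, int i, int newValue) {
        if (i < 0 || i >= a.length) {
            throw new IllegalArgumentException("index out of range: " + i);
        }
        int diff = newValue - a[i];
        a[i] = newValue; // keep the array and the tree consistent
        updateValueUtil(tree, 0, a.length - 1, i, diff, 0);
    }

    public static void main(String[] args) {
        int[] a = {1, 3, -2, 8, -7};
        // Assumed sum tree for a
        int[] tree = {3, 2, 1, 4, -2, 8, -7, 1, 3, 0, 0, 0, 0, 0, 0};
        updateValue(a, tree, 2, 5);   // a[2]: -2 -> 5, diff = 7
        System.out.println(tree[0]);  // prints 10
        System.out.println(tree[1]);  // prints 9
    }
}
```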
To find the sum over a range A[left…right] of the array, we traverse the Segment Tree and use the precomputed sums of the segments. Assume that we are currently at the node that covers the segment A[start…end]. There are three possible cases:
- Case 1: The easiest case is when the query segment A[left…right] is equal to the segment of the current vertex (i.e. A[left…right]=A[start…end]). Then we are finished and can return the precomputed sum that is stored in the node.
- Case 2: The query segment falls completely into the domain of either the left or the right child. In this case we go to the child vertex whose segment covers the query segment and execute the algorithm described here with that node.
- Case 3: The query segment intersects with both children. In this case we have no other option but to make two recursive calls, one for each child. First we go to the left child and compute a partial answer for this vertex (i.e. the sum of values in the intersection between the query segment and the segment of the left child), then we go to the right child and compute the partial answer using that vertex, and finally we combine the answers by adding them.
So processing a sum query is a function that recursively calls itself either once, with the left or the right child (without changing the query boundaries), or twice, once for the left and once for the right child (splitting the query into two subqueries). The recursion ends whenever the boundaries of the current query segment coincide with the boundaries of the segment of the current vertex. In that case the answer is the precomputed sum of this segment, which is stored in the tree. The function below reaches the same result with a slightly different test: it keeps the query boundaries fixed, returns the stored sum when the node's segment lies entirely inside the query, and returns 0 when the two do not overlap.
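The case analysis above, with the query split at the midpoint of the current segment, can be sketched as follows. This is an illustrative variant under hypothetical names (`SplitQuery`, `sum`), and it assumes the hard-coded array is the sum tree for a = [1, 3, -2, 8, -7]. For brevity it always recurses into both children and lets an empty subquery return 0, which covers case 2 without a separate branch.

```java
class SplitQuery {
    // Sum of A[left..right]; [start, end] is the segment of the node at idx.
    // Requires start <= left and right <= end whenever left <= right.
    static int sum(int[] tree, int idx, int start, int end, int left, int right) {
        if (left > right) {
            return 0; // empty subquery: this child does not intersect the query
        }
        if (left == start && right == end) {
            return tree[idx]; // case 1: query equals this node's segment
        }
        int mid = start + (end - start) / 2;
        // Cases 2 and 3: clip the query to each child's segment.
        return sum(tree, 2 * idx + 1, start, mid, left, Math.min(right, mid))
             + sum(tree, 2 * idx + 2, mid + 1, end, Math.max(left, mid + 1), right);
    }

    public static void main(String[] args) {
        // Assumed sum tree for a = [1, 3, -2, 8, -7]
        int[] tree = {3, 2, 1, 4, -2, 8, -7, 1, 3, 0, 0, 0, 0, 0, 0};
        System.out.println(sum(tree, 0, 0, 4, 1, 3)); // prints 9 (3 + -2 + 8)
    }
}
```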
    // Returns the sum of A[left..right] using the Segment Tree array tree.
    // [start, end] is the segment covered by the node at idx;
    // initially start = 0, end = N-1 and idx = 0.
    int getSumUtil(int[] tree, int start, int end, int left, int right, int idx) {
        // Segment of this node lies completely inside the given range, return its sum
        if (left <= start && right >= end) {
            return tree[idx];
        }
        // Segment of this node is outside the given range
        if (end < left || start > right) {
            return 0;
        }
        // Segment partially overlaps the given range, combine the answers of both children
        int mid = start + (end - start) / 2;
        return getSumUtil(tree, start, mid, left, right, 2 * idx + 1)
             + getSumUtil(tree, mid + 1, end, left, right, 2 * idx + 2);
    }
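A caller would normally start this recursion at the root and check the query bounds first. The sketch below shows one way to do that; `QueryExample` and `getSum` are hypothetical names, and the hard-coded array is assumed to be the sum tree for a = [1, 3, -2, 8, -7] with N = 5.

```java
class QueryExample {
    // Same overlap-based query as above.
    static int getSumUtil(int[] tree, int start, int end, int left, int right, int idx) {
        if (left <= start && right >= end) {
            return tree[idx];
        }
        if (end < left || start > right) {
            return 0;
        }
        int mid = start + (end - start) / 2;
        return getSumUtil(tree, start, mid, left, right, 2 * idx + 1)
             + getSumUtil(tree, mid + 1, end, left, right, 2 * idx + 2);
    }

    // Hypothetical wrapper: validates the range and starts at the root,
    // which covers A[0..n-1].
    static int getSum(int[] tree, int n, int left, int right) {
        if (left < 0 || right >= n || left > right) {
            throw new IllegalArgumentException("invalid range");
        }
        return getSumUtil(tree, 0, n - 1, left, right, 0);
    }

    public static void main(String[] args) {
        // Assumed sum tree for a = [1, 3, -2, 8, -7]
        int[] tree = {3, 2, 1, 4, -2, 8, -7, 1, 3, 0, 0, 0, 0, 0, 0};
        System.out.println(getSum(tree, 5, 1, 3)); // prints 9
        System.out.println(getSum(tree, 5, 0, 4)); // prints 3
        System.out.println(getSum(tree, 5, 4, 4)); // prints -7
    }
}
```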