Segment Tree & Binary Indexed Tree (Fenwick Tree)
Overview
Segment Trees and Binary Indexed Trees (BIT/Fenwick Tree) are advanced data structures for efficiently handling range queries and updates on arrays.
Key Properties
- Time Complexity: O(log n) for both query and update operations
- Space Complexity: O(n) for both; a BIT stores n + 1 values, while a recursive Segment Tree is usually allocated 4n slots
- Core Idea: Precompute range information in a tree structure for fast queries
- When to Use: Range sum/min/max queries with updates, order statistics
- Key Operations: Range query, point/range update, build tree
Core Characteristics
- Range Queries: Sum, minimum, maximum, GCD, XOR over ranges
- Point Updates: Modify single element efficiently
- Range Updates: Modify entire ranges (with lazy propagation)
- Space-Time Tradeoff: Extra space for faster query processing
Problem Categories
Category 1: Range Sum Queries
- Description: Calculate sum over ranges with updates
- Examples: LC 307 (Range Sum Query - Mutable), LC 308 (Range Sum Query 2D - Mutable)
- Pattern: Use BIT or Segment Tree for point updates, range queries
Category 2: Range Minimum/Maximum Queries
- Description: Find min/max in ranges with updates
- Examples: Custom RMQ problems (range min/max with point updates)
- Pattern: Segment Tree with min/max operations
Category 3: Range Updates with Lazy Propagation
- Description: Update entire ranges efficiently
- Examples: Add value to range, set range to value
- Pattern: Segment Tree with lazy propagation
Category 4: Order Statistics & Inversions
- Description: Count smaller/larger elements, inversions
- Examples: LC 315 (Count Smaller), LC 493 (Reverse Pairs), LC 327 (Count Range Sum)
- Pattern: BIT with coordinate compression or merge sort
Data Structure Comparison
BIT vs Segment Tree Comparison
| Aspect | Binary Indexed Tree | Segment Tree |
|---|---|---|
| Space | O(n) (n + 1 slots) | O(n) (up to 4n slots) |
| Implementation | Simple, short code | More complex |
| Operations | Invertible operations such as sum and XOR (prefix-only for OR, max) | Any associative operation |
| Range Updates | Limited (range add via difference BITs) | Easy with lazy propagation |
| 1-indexed | Natural fit | Can be adapted |
| Query Types | Prefix queries easy | Arbitrary range queries |
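The table lists range updates as a BIT weakness. A common workaround for range add with range sum keeps two Fenwick arrays: adding v to [l, r] records v at l and -v at r + 1 in the first, plus matching corrections in the second, so that prefix_sum(i) = sum1(i) * i - sum2(i). The sketch below assumes 1-indexed positions; the class and method names are hypothetical and not part of the templates in this guide.

```python
class RangeAddRangeSumBIT:
    """Range add + range sum with two Fenwick arrays (1-indexed, hypothetical class)."""

    def __init__(self, n):
        self.n = n
        self.b1 = [0] * (n + 1)  # coefficient multiplied by i in the prefix sum
        self.b2 = [0] * (n + 1)  # constant correction subtracted from the prefix sum

    def _add(self, tree, i, delta):
        while i <= self.n:
            tree[i] += delta
            i += i & (-i)

    def _sum(self, tree, i):
        total = 0
        while i > 0:
            total += tree[i]
            i -= i & (-i)
        return total

    def range_add(self, left, right, val):
        """Add val to every element in [left, right] (inclusive)."""
        self._add(self.b1, left, val)
        self._add(self.b1, right + 1, -val)  # no-op when right == n
        self._add(self.b2, left, val * (left - 1))
        self._add(self.b2, right + 1, -val * right)

    def prefix_sum(self, i):
        """Sum of elements 1..i."""
        return self._sum(self.b1, i) * i - self._sum(self.b2, i)

    def range_sum(self, left, right):
        return self.prefix_sum(right) - self.prefix_sum(left - 1)


bit = RangeAddRangeSumBIT(5)
bit.range_add(2, 4, 3)   # array is now [0, 3, 3, 3, 0]
bit.range_add(1, 5, 1)   # array is now [1, 4, 4, 4, 1]
print(bit.range_sum(1, 5), bit.range_sum(2, 3))  # 14 8
```

Both updates and queries stay O(log n), which is why the BIT column of the time-complexity table can list O(log n) range updates.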
Templates & Algorithms
Template 1: Binary Indexed Tree (Fenwick Tree)
class BIT:
    """Binary Indexed Tree for range sum queries and point updates"""
    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)  # 1-indexed; tree[0] is unused
    def update(self, i, delta):
        """Add delta to element at index i (1-indexed)"""
        while i <= self.n:
            self.tree[i] += delta
            i += i & (-i)  # Add lowest set bit: move to the next node that covers i
    def query(self, i):
        """Get prefix sum from 1 to i"""
        total = 0
        while i > 0:
            total += self.tree[i]
            i -= i & (-i)  # Remove lowest set bit: jump to the preceding block
        return total
    def range_query(self, left, right):
        """Get sum from left to right (inclusive, 1-indexed)"""
        # query(0) returns 0, so left == 1 needs no special case
        return self.query(right) - self.query(left - 1)
    def build(self, arr):
        """Build BIT from a 1-indexed array (arr[0] is ignored); O(n log n)"""
        for i in range(1, len(arr)):
            self.update(i, arr[i])
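The build method above calls update once per element, which costs O(n log n). Because tree[i] sums the 1-indexed range (i - lowbit(i), i], with lowbit(i) = i & (-i), each node can instead be added into its parent i + lowbit(i) in a single left-to-right pass. A minimal sketch of that linear build, assuming a 0-indexed input list; build_linear and covered_range are hypothetical helper names:

```python
def build_linear(arr):
    """Return a 1-indexed Fenwick array for arr in O(n) (hypothetical helper)."""
    n = len(arr)
    tree = [0] + list(arr)  # shift to 1-indexed; tree[0] stays unused
    for i in range(1, n + 1):
        parent = i + (i & (-i))
        if parent <= n:
            # tree[i] is complete here because all of its children are < i
            tree[parent] += tree[i]
    return tree


def covered_range(i):
    """1-indexed inclusive range whose sum is stored in tree[i]."""
    return (i - (i & (-i)) + 1, i)


tree = build_linear([1, 2, 3, 4, 5, 6, 7, 8])
print(tree)              # [0, 1, 3, 3, 10, 5, 11, 7, 36]
print(covered_range(6))  # (5, 6)
```

The resulting array is identical to what repeated update calls produce, so it can be assigned to a BIT instance's tree field directly.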
Template 2: Segment Tree (Range Sum)
class SegmentTree:
    """Segment Tree for range sum queries and point updates"""
    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)  # 4n slots always suffice for this recursive layout
        if self.n > 0:  # Guard: building an empty array would never reach a base case
            self.build(arr, 1, 0, self.n - 1)
    def build(self, arr, node, start, end):
        """Build segment tree recursively"""
        if start == end:
            self.tree[node] = arr[start]
        else:
            mid = (start + end) // 2
            self.build(arr, 2 * node, start, mid)
            self.build(arr, 2 * node + 1, mid + 1, end)
            self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
    def update(self, node, start, end, idx, val):
        """Update single element at index idx to val"""
        if start == end:
            self.tree[node] = val
        else:
            mid = (start + end) // 2
            if idx <= mid:
                self.update(2 * node, start, mid, idx, val)
            else:
                self.update(2 * node + 1, mid + 1, end, idx, val)
            self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
    def query(self, node, start, end, left, right):
        """Query sum in range [left, right]"""
        if right < start or end < left:
            return 0  # No overlap
        if left <= start and end <= right:
            return self.tree[node]  # Complete overlap
        # Partial overlap
        mid = (start + end) // 2
        left_sum = self.query(2 * node, start, mid, left, right)
        right_sum = self.query(2 * node + 1, mid + 1, end, left, right)
        return left_sum + right_sum
    # Public interface methods
    def point_update(self, idx, val):
        """Update element at index idx to val"""
        self.update(1, 0, self.n - 1, idx, val)
    def range_sum(self, left, right):
        """Get sum in range [left, right]"""
        return self.query(1, 0, self.n - 1, left, right)
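Template 2 uses the recursive layout with 4n slots. A widely used alternative for commutative operations such as sum is the iterative bottom-up layout: leaves sit at tree[n:2n], node i combines 2i and 2i + 1, and only 2n slots are needed for any n. The sketch below is a hypothetical class; note that its range_sum takes a half-open range [left, right), unlike Template 2's inclusive interface.

```python
class IterativeSegmentTree:
    """Bottom-up segment tree for range sums with 2n slots (hypothetical class)."""

    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * self.n + list(arr)  # leaves occupy tree[n:2n]
        for i in range(self.n - 1, 0, -1):
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def point_update(self, idx, val):
        """Set arr[idx] = val (0-indexed) and fix the ancestors."""
        i = idx + self.n
        self.tree[i] = val
        while i > 1:
            i //= 2
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def range_sum(self, left, right):
        """Sum of arr[left:right] (half-open)."""
        total = 0
        lo, hi = left + self.n, right + self.n
        while lo < hi:
            if lo & 1:  # lo is a right child: take it and step past it
                total += self.tree[lo]
                lo += 1
            if hi & 1:  # hi is a right child: its left sibling is inside the range
                hi -= 1
                total += self.tree[hi]
            lo //= 2
            hi //= 2
        return total


st = IterativeSegmentTree([2, 1, 5, 3, 4])
print(st.range_sum(1, 4))  # 9  (1 + 5 + 3)
st.point_update(2, 10)
print(st.range_sum(1, 4))  # 14
```

The recursive template remains the easier base for lazy propagation; the iterative form trades that flexibility for less memory and no recursion.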
Template 3: Segment Tree with Lazy Propagation
class LazySegmentTree:
    """Segment Tree with lazy propagation for range updates"""
    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)
        self.lazy = [0] * (4 * self.n)  # lazy[node]: pending add not yet applied to tree[node]
        if self.n > 0:  # Guard: building an empty array would never reach a base case
            self.build(arr, 1, 0, self.n - 1)
    def build(self, arr, node, start, end):
        if start == end:
            self.tree[node] = arr[start]
        else:
            mid = (start + end) // 2
            self.build(arr, 2 * node, start, mid)
            self.build(arr, 2 * node + 1, mid + 1, end)
            self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
    def push(self, node, start, end):
        """Apply this node's pending add and pass it down to its children"""
        if self.lazy[node] != 0:
            self.tree[node] += self.lazy[node] * (end - start + 1)
            if start != end:  # Not a leaf node
                self.lazy[2 * node] += self.lazy[node]
                self.lazy[2 * node + 1] += self.lazy[node]
            self.lazy[node] = 0
    def update_range(self, node, start, end, left, right, val):
        """Add val to range [left, right]"""
        self.push(node, start, end)
        if start > right or end < left:
            return
        if start >= left and end <= right:
            self.lazy[node] += val
            self.push(node, start, end)
            return
        mid = (start + end) // 2
        self.update_range(2 * node, start, mid, left, right, val)
        self.update_range(2 * node + 1, mid + 1, end, left, right, val)
        # Make sure both children are current before recombining
        self.push(2 * node, start, mid)
        self.push(2 * node + 1, mid + 1, end)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]
    def query_range(self, node, start, end, left, right):
        """Query sum in range [left, right]"""
        if start > right or end < left:
            return 0
        self.push(node, start, end)
        if start >= left and end <= right:
            return self.tree[node]
        mid = (start + end) // 2
        left_sum = self.query_range(2 * node, start, mid, left, right)
        right_sum = self.query_range(2 * node + 1, mid + 1, end, left, right)
        return left_sum + right_sum
    # Public interface
    def range_add(self, left, right, val):
        """Add val to range [left, right]"""
        self.update_range(1, 0, self.n - 1, left, right, val)
    def range_sum(self, left, right):
        """Get sum in range [left, right]"""
        return self.query_range(1, 0, self.n - 1, left, right)
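Category 3 also mentions setting a range to a value, which Template 3 does not cover. A minimal sketch of range assignment with range sums follows, assuming 0-indexed inclusive ranges like Template 3; the class and method names are hypothetical. It uses a slightly different convention from Template 3: tree[node] is always current, and lazy[node] holds an assignment still owed to the children. None marks "nothing pending" so that assigning 0 still works.

```python
class AssignSumSegmentTree:
    """Range assignment + range sum with lazy propagation (hypothetical class)."""

    def __init__(self, arr):
        self.n = len(arr)
        size = 4 * max(self.n, 1)
        self.tree = [0] * size
        self.lazy = [None] * size  # pending assignment for the children, or None
        if self.n:
            self._build(arr, 1, 0, self.n - 1)

    def _build(self, arr, node, start, end):
        if start == end:
            self.tree[node] = arr[start]
            return
        mid = (start + end) // 2
        self._build(arr, 2 * node, start, mid)
        self._build(arr, 2 * node + 1, mid + 1, end)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def _apply(self, node, start, end, val):
        # Overwrite this segment's sum and remember the assignment for its children
        self.tree[node] = val * (end - start + 1)
        self.lazy[node] = val

    def _push(self, node, start, end):
        if self.lazy[node] is not None and start != end:
            mid = (start + end) // 2
            self._apply(2 * node, start, mid, self.lazy[node])
            self._apply(2 * node + 1, mid + 1, end, self.lazy[node])
        self.lazy[node] = None

    def _assign(self, node, start, end, left, right, val):
        if right < start or end < left:
            return
        if left <= start and end <= right:
            self._apply(node, start, end, val)
            return
        self._push(node, start, end)
        mid = (start + end) // 2
        self._assign(2 * node, start, mid, left, right, val)
        self._assign(2 * node + 1, mid + 1, end, left, right, val)
        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def _query(self, node, start, end, left, right):
        if right < start or end < left:
            return 0
        if left <= start and end <= right:
            return self.tree[node]
        self._push(node, start, end)
        mid = (start + end) // 2
        return (self._query(2 * node, start, mid, left, right)
                + self._query(2 * node + 1, mid + 1, end, left, right))

    def range_assign(self, left, right, val):
        self._assign(1, 0, self.n - 1, left, right, val)

    def range_sum(self, left, right):
        return self._query(1, 0, self.n - 1, left, right)


st = AssignSumSegmentTree([1, 2, 3, 4, 5])
st.range_assign(1, 3, 10)   # [1, 10, 10, 10, 5]
print(st.range_sum(0, 4))   # 36
```

Unlike adds, assignments do not compose by summing, which is why a newer assignment simply overwrites the pending one.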
Template 4: 2D Binary Indexed Tree
class BIT2D:
    """2D Binary Indexed Tree for 2D range sum queries"""
    def __init__(self, rows, cols):
        self.rows = rows
        self.cols = cols
        self.tree = [[0] * (cols + 1) for _ in range(rows + 1)]
    def update(self, row, col, delta):
        """Add delta to element at (row, col)"""
        orig_col = col
        while row <= self.rows:
            col = orig_col
            while col <= self.cols:
                self.tree[row][col] += delta
                col += col & (-col)
            row += row & (-row)
    def query(self, row, col):
        """Get sum from (1,1) to (row, col)"""
        total = 0
        orig_col = col
        while row > 0:
            col = orig_col
            while col > 0:
                total += self.tree[row][col]
                col -= col & (-col)
            row -= row & (-row)
        return total
    def range_query(self, row1, col1, row2, col2):
        """Get sum in rectangle from (row1, col1) to (row2, col2), 1-indexed inclusive"""
        return (self.query(row2, col2) -
                self.query(row1 - 1, col2) -
                self.query(row2, col1 - 1) +
                self.query(row1 - 1, col1 - 1))
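BIT2D.range_query relies on inclusion-exclusion: take the big prefix rectangle, subtract the strip above and the strip to the left, then add back the corner that was subtracted twice. The same four-term formula works on a static 2D prefix-sum array, which is the LC 304 technique; a short sketch (function names are hypothetical):

```python
def build_prefix_2d(grid):
    """P[r][c] = sum of grid[0..r-1][0..c-1]; row 0 and column 0 are zero padding."""
    rows = len(grid)
    cols = len(grid[0]) if rows else 0
    P = [[0] * (cols + 1) for _ in range(rows + 1)]
    for r in range(1, rows + 1):
        for c in range(1, cols + 1):
            P[r][c] = grid[r - 1][c - 1] + P[r - 1][c] + P[r][c - 1] - P[r - 1][c - 1]
    return P


def rect_sum(P, row1, col1, row2, col2):
    """Sum over a 1-indexed inclusive rectangle, same formula as BIT2D.range_query."""
    return P[row2][col2] - P[row1 - 1][col2] - P[row2][col1 - 1] + P[row1 - 1][col1 - 1]


grid = [[3, 0, 1],
        [5, 6, 3],
        [1, 2, 0]]
P = build_prefix_2d(grid)
print(rect_sum(P, 2, 2, 3, 3))  # 11  (6 + 3 + 2 + 0)
```

The prefix array answers queries in O(1) but must be rebuilt after any change, which is exactly the gap the 2D BIT fills.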
LeetCode Problems & Solutions
Range Sum Query Problems
| Problem | LC # | Data Structure | Difficulty | Key Technique |
|---|---|---|---|---|
| Range Sum Query - Immutable | 303 | Prefix Sum | Easy | Simple prefix array |
| Range Sum Query - Mutable | 307 | BIT/Segment Tree | Medium | Point update, range query |
| Range Sum Query 2D - Immutable | 304 | 2D Prefix Sum | Medium | 2D prefix array |
| Range Sum Query 2D - Mutable | 308 | 2D BIT | Hard | 2D point update, range query |
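For the immutable variant (LC 303) no tree is needed: with no updates, a prefix-sum array gives O(n) build and O(1) queries. A minimal sketch; LeetCode names the class NumArray, renamed here to NumArrayImmutable (hypothetical) to avoid clashing with the LC 307 solution:

```python
from itertools import accumulate


class NumArrayImmutable:
    """LC 303 sketch: prefix sums, O(n) build and O(1) sumRange, no updates."""

    def __init__(self, nums):
        # prefix[i] = nums[0] + ... + nums[i - 1]
        self.prefix = [0] + list(accumulate(nums))

    def sumRange(self, left, right):
        return self.prefix[right + 1] - self.prefix[left]


na = NumArrayImmutable([-2, 0, 3, -5, 2, -1])
print(na.sumRange(0, 2), na.sumRange(2, 5))  # 1 -1
```

A single point update would force rebuilding the whole prefix array in O(n), which is what motivates the BIT and Segment Tree for LC 307.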
Order Statistics Problems
| Problem | LC # | Data Structure | Difficulty | Key Technique |
|---|---|---|---|---|
| Count of Smaller Numbers After Self | 315 | BIT + Compression | Hard | Coordinate compression |
| Reverse Pairs | 493 | BIT/Merge Sort | Hard | Count inversions |
| Count of Range Sum | 327 | BIT + Prefix Sum | Hard | Coordinate compression |
LC 307: Range Sum Query - Mutable
class NumArray:
    """Range Sum Query with updates using BIT"""
    def __init__(self, nums):
        self.nums = [0] + nums  # Make 1-indexed
        self.bit = BIT(len(nums))
        # Build BIT
        for i in range(1, len(self.nums)):
            self.bit.update(i, self.nums[i])
    def update(self, index, val):
        """Update element at index to val"""
        index += 1  # Convert to 1-indexed
        delta = val - self.nums[index]  # BIT stores sums, so apply the difference
        self.nums[index] = val
        self.bit.update(index, delta)
    def sumRange(self, left, right):
        """Sum elements from left to right"""
        return self.bit.range_query(left + 1, right + 1)
# Alternative using Segment Tree
class NumArraySegTree:
    def __init__(self, nums):
        self.seg_tree = SegmentTree(nums)
        self.nums = nums
    def update(self, index, val):
        self.nums[index] = val
        self.seg_tree.point_update(index, val)
    def sumRange(self, left, right):
        return self.seg_tree.range_sum(left, right)
LC 315: Count of Smaller Numbers After Self
def countSmaller(nums):
    """Count smaller numbers after self using BIT"""
    if not nums:
        return []
    # Coordinate compression: map each distinct value to a rank in 1..k
    sorted_nums = sorted(set(nums))
    rank = {num: i + 1 for i, num in enumerate(sorted_nums)}
    bit = BIT(len(sorted_nums))
    result = []
    # Process from right to left so the BIT only holds elements after i
    for i in range(len(nums) - 1, -1, -1):
        # Count numbers strictly smaller than nums[i]; query(0) is 0 for the smallest rank
        count = bit.query(rank[nums[i]] - 1)
        result.append(count)
        # Add current number to BIT
        bit.update(rank[nums[i]], 1)
    return result[::-1]  # Reverse to get correct order
# Alternative using merge sort
def countSmallerMergeSort(nums):
    """Using merge sort to count, for each element, the smaller elements to its right"""
    counts = [0] * len(nums)  # Indexed by original position, shared by every merge
    def mergeSort(arr):
        if len(arr) <= 1:
            return arr
        mid = len(arr) // 2
        left = mergeSort(arr[:mid])
        right = mergeSort(arr[mid:])
        merged = []
        i = j = 0
        while i < len(left) and j < len(right):
            if left[i][0] <= right[j][0]:
                merged.append(left[i])
                counts[left[i][1]] += j  # j elements from right are strictly smaller
                i += 1
            else:
                merged.append(right[j])
                j += 1
        while i < len(left):
            merged.append(left[i])
            counts[left[i][1]] += j
            i += 1
        while j < len(right):
            merged.append(right[j])
            j += 1
        return merged
    # Create (value, original_index) pairs
    indexed_nums = [(nums[i], i) for i in range(len(nums))]
    mergeSort(indexed_nums)
    return counts
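Both countSmaller and countSmallerMergeSort are easy to get subtly wrong (off-by-one ranks, ties counted as smaller). A quadratic reference implementation is a convenient oracle for checking them on small inputs; count_smaller_brute is a hypothetical name:

```python
def count_smaller_brute(nums):
    """O(n^2) reference for LC 315: strictly smaller elements to the right."""
    return [sum(1 for later in nums[i + 1:] if later < x) for i, x in enumerate(nums)]


print(count_smaller_brute([5, 2, 6, 1]))  # [2, 1, 1, 0]
```

Comparing this against either fast version on a few hundred random short lists catches most rank and tie-handling bugs.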
LC 493: Reverse Pairs
def reversePairs(nums):
    """Count reverse pairs (i < j and nums[i] > 2 * nums[j]) using BIT and coordinate compression"""
    if not nums:
        return 0
    # Get all possible values (including doubled values)
    values = set(nums)
    for num in nums:
        values.add(2 * num)
    # Coordinate compression
    sorted_values = sorted(values)
    rank = {val: i + 1 for i, val in enumerate(sorted_values)}
    bit = BIT(len(sorted_values))
    count = 0
    # Scan left to right: when num plays the role of nums[j], the BIT holds every nums[i] with i < j
    for num in nums:
        # Count already-seen numbers > 2 * num, i.e. ranks target_rank + 1 .. end
        target_rank = rank[2 * num]
        count += bit.query(len(sorted_values)) - bit.query(target_rank)
        # Add current number to BIT
        bit.update(rank[num], 1)
    return count
# Alternative merge sort approach (sorts nums in place)
import heapq
def reversePairsMergeSort(nums):
    def mergeSort(arr, start, end):
        if start >= end:
            return 0
        mid = (start + end) // 2
        count = mergeSort(arr, start, mid) + mergeSort(arr, mid + 1, end)
        # Count reverse pairs: both halves are sorted, so j only moves forward
        j = mid + 1
        for i in range(start, mid + 1):
            while j <= end and arr[i] > 2 * arr[j]:
                j += 1
            count += j - (mid + 1)
        # Merge the two sorted halves in linear time
        arr[start:end + 1] = list(heapq.merge(arr[start:mid + 1], arr[mid + 1:end + 1]))
        return count
    return mergeSort(nums, 0, len(nums) - 1)
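A third option, handy as a compact cross-check for reversePairs and reversePairsMergeSort, keeps the earlier values in a sorted Python list and uses binary search. This is a different technique from both versions above: bisect.insort is O(n) per insertion, so the worst case is O(n^2). The function name is hypothetical.

```python
from bisect import bisect_right, insort


def reverse_pairs_bisect(nums):
    """Count i < j with nums[i] > 2 * nums[j] using a sorted list of earlier values."""
    seen = []  # nums[0..j-1], kept sorted
    count = 0
    for num in nums:
        # Earlier values strictly greater than 2 * num
        count += len(seen) - bisect_right(seen, 2 * num)
        insort(seen, num)
    return count


print(reverse_pairs_bisect([2, 4, 3, 5, 1]))  # 3
```

The scan direction matters: num plays the role of nums[j] and seen holds the nums[i] with i < j.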
LC 327: Count of Range Sum
def countRangeSum(nums, lower, upper):
    """Count range sums in [lower, upper] using BIT"""
    if not nums:
        return 0
    # Compute prefix sums
    prefix_sums = [0]
    for num in nums:
        prefix_sums.append(prefix_sums[-1] + num)
    # Get all relevant values for coordinate compression
    values = set(prefix_sums)
    for ps in prefix_sums:
        values.add(ps - lower)
        values.add(ps - upper)
    sorted_values = sorted(values)
    rank = {val: i + 1 for i, val in enumerate(sorted_values)}
    bit = BIT(len(sorted_values))
    count = 0
    for ps in prefix_sums:
        # Count earlier prefix sums in range [ps - upper, ps - lower]
        left_rank = rank[ps - upper]
        right_rank = rank[ps - lower]
        count += bit.range_query(left_rank, right_rank)
        # Add current prefix sum to BIT
        bit.update(rank[ps], 1)
    return count
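The BIT solution rests on one identity: sum(nums[i:j]) equals prefix[j] - prefix[i], so a range sum lies in [lower, upper] exactly when an earlier prefix sum lies in [prefix[j] - upper, prefix[j] - lower]. A quadratic reference that applies the identity directly is useful for testing countRangeSum; the function name is hypothetical:

```python
from itertools import accumulate


def count_range_sum_brute(nums, lower, upper):
    """O(n^2) reference for LC 327 built on prefix-sum differences."""
    prefix = [0] + list(accumulate(nums))
    return sum(
        1
        for j in range(1, len(prefix))
        for i in range(j)
        if lower <= prefix[j] - prefix[i] <= upper
    )


print(count_range_sum_brute([-2, 5, -1], -2, 2))  # 3
```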
Advanced Techniques
Coordinate Compression
def coordinate_compress(arr):
    """Compress coordinates for BIT usage: map each distinct value to a rank in 1..k"""
    unique_vals = sorted(set(arr))
    rank_map = {val: i + 1 for i, val in enumerate(unique_vals)}
    return rank_map, len(unique_vals)
def use_compression_example():
    nums = [100, 1, 50, 200, 75]
    rank_map, max_rank = coordinate_compress(nums)
    # rank_map = {1: 1, 50: 2, 75: 3, 100: 4, 200: 5}
    bit = BIT(max_rank)
    for num in nums:
        bit.update(rank_map[num], 1)  # Add frequency
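A dictionary only knows the values inserted up front, which is why reversePairs adds every 2 * num before compressing. A binary search over the sorted distinct values can translate any threshold into a prefix length instead. A sketch under that assumption; make_ranker and its returned functions are hypothetical:

```python
from bisect import bisect_left, bisect_right


def make_ranker(values):
    """Build rank lookups over the sorted distinct values (hypothetical helper).

    Returns (rank_of, count_le, size):
      rank_of(x)  -> 1-based rank of x; x must be one of the original values
      count_le(x) -> number of distinct values <= x, for any x (stored or not)
    """
    uniq = sorted(set(values))

    def rank_of(x):
        return bisect_left(uniq, x) + 1

    def count_le(x):
        # A BIT prefix query up to this rank counts all stored values <= x
        return bisect_right(uniq, x)

    return rank_of, count_le, len(uniq)


rank_of, count_le, size = make_ranker([100, 1, 50, 200, 75])
print(rank_of(75), count_le(80), size)  # 3 3 5
```

With this helper, a query such as bit.query(count_le(2 * num)) counts stored values up to 2 * num without adding the doubled values to the compression step.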
Range Maximum Query (RMQ) Segment Tree
class RMQSegmentTree:
    """Segment Tree for Range Maximum Queries"""
    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)
        if self.n > 0:  # Guard: building an empty array would never reach a base case
            self.build(arr, 1, 0, self.n - 1)
    def build(self, arr, node, start, end):
        if start == end:
            self.tree[node] = arr[start]
        else:
            mid = (start + end) // 2
            self.build(arr, 2 * node, start, mid)
            self.build(arr, 2 * node + 1, mid + 1, end)
            self.tree[node] = max(self.tree[2 * node], self.tree[2 * node + 1])
    def query_max(self, node, start, end, left, right):
        if right < start or end < left:
            return float('-inf')  # Identity element for max
        if left <= start and end <= right:
            return self.tree[node]
        mid = (start + end) // 2
        left_max = self.query_max(2 * node, start, mid, left, right)
        right_max = self.query_max(2 * node + 1, mid + 1, end, left, right)
        return max(left_max, right_max)
    def range_max(self, left, right):
        return self.query_max(1, 0, self.n - 1, left, right)
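When the array never changes, the Sparse Table listed in the Time Complexity Comparison answers range-maximum queries in O(1) after O(n log n) preprocessing. table[k][i] stores the maximum of the 2^k elements starting at i, and a query covers [left, right] with two overlapping blocks. The overlap is harmless only because max is idempotent, so this trick does not carry over to sums. A sketch with a hypothetical class name:

```python
class SparseTableMax:
    """Static range-maximum queries: O(n log n) build, O(1) query, no updates."""

    def __init__(self, arr):
        n = len(arr)
        self.table = [list(arr)]  # table[k][i] = max(arr[i : i + 2**k])
        k = 1
        while (1 << k) <= n:
            prev = self.table[-1]
            half = 1 << (k - 1)
            self.table.append([max(prev[i], prev[i + half])
                               for i in range(n - (1 << k) + 1)])
            k += 1

    def range_max(self, left, right):
        """Max of arr[left..right], inclusive, 0-indexed."""
        k = (right - left + 1).bit_length() - 1  # largest power of two that fits
        row = self.table[k]
        return max(row[left], row[right - (1 << k) + 1])


sp = SparseTableMax([5, 2, 4, 7, 1, 3, 6])
print(sp.range_max(0, 6), sp.range_max(4, 6))  # 7 6
```

Any point update would require rebuilding the table, so the segment tree above remains the choice when values change.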
Time Complexity Comparison
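| Operation | Naive Array | BIT | Segment Tree | Sparse Table |
|---|---|---|---|---|
| Build | O(1) | O(n log n) (O(n) with a linear build) | O(n) | O(n log n) |
| Point Update | O(1) | O(log n) | O(log n) | O(n log n) (full rebuild) |
| Range Query | O(n) | O(log n) | O(log n) | O(1) for min/max/GCD; O(log n) for sums |
| Range Update | O(n) | O(log n) (range add via two BITs) | O(log n) (lazy propagation) | O(n log n) (full rebuild) |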
Space Complexity
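- BIT: O(n), exactly n + 1 slots, very space efficient
- Segment Tree: O(n), but up to 4n slots in the recursive layout; needs more space but is more flexible
- 2D BIT: O(n×m), proportional to the number of grid cells
- Lazy Segment Tree: O(n), with two 4n arrays (tree and lazy), about twice a regular segment tree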
Implementation Tips
Common Pitfalls & Solutions
def bit_pitfalls():
    """Common BIT implementation mistakes"""
    # ❌ Wrong: 0-indexed BIT
    # BIT naturally works with 1-indexed arrays; update(0, ...) never ends because 0 & -0 == 0
    # ✅ Correct: Convert to 1-indexed
    def update_correct(bit, index, delta):
        index += 1  # Convert 0-indexed to 1-indexed
        while index <= bit.n:
            bit.tree[index] += delta
            index += index & (-index)
    # ❌ Wrong: Forgetting coordinate compression
    def wrong_approach(nums):
        bit = BIT(max(nums))  # Huge for values like 10**9, and cannot index zero or negative values
    # ✅ Correct: Use coordinate compression
    def correct_approach(nums):
        rank_map, size = coordinate_compress(nums)
        bit = BIT(size)
        for num in nums:
            bit.update(rank_map[num], 1)
def segment_tree_tips(n):
    """Segment Tree best practices"""
    # Use 4n space allocation for safety
    tree = [0] * (4 * n)
    # Handle edge cases properly
    def query(node, start, end, left, right):
        if right < start or end < left:
            return 0  # Return identity element (0 for sum, -inf for max, +inf for min)
        # ... rest of query logic
Summary & Quick Reference
When to Use Each Structure
| Use Case | Best Choice | Why |
|---|---|---|
| Range Sum + Point Updates | BIT | Simple, space-efficient |
| Range Min/Max + Updates | Segment Tree | Supports any associative operation |
| Range Updates | Lazy Segment Tree | O(log n) per range update |
| 2D Range Queries | 2D BIT | Natural extension |
| Count Inversions | BIT + Compression | Prefix counts over ranks answer order-statistic queries |
Implementation Checklist
- [ ] BIT: Remember 1-indexing, use coordinate compression for large values
- [ ] Segment Tree: Allocate 4n space, handle query edge cases
- [ ] Lazy Propagation: Implement push correctly, update children lazily
- [ ] 2D Structures: Consider memory usage, test with small examples first
LeetCode Problem Categories
- Range Sum: LC 303, 307, 308 (BIT/Segment Tree)
- Order Statistics: LC 315, 327, 493 (BIT + Compression)
- Dynamic Programming: Range DP with RMQ optimization
- Geometry: 2D range queries, rectangle problems
This comprehensive guide covers the essential concepts and implementations for Segment Trees and Binary Indexed Trees, with practical examples from LeetCode problems.