# 线段树¶

## 1、什么是线段树？¶

假设有长度为10w的数组，定义两种操作：
- 查询区间[i,j]中所有元素的和
- 更新位置i上面的数字


• 如果每次查询都计算一遍，那么操作1的时间复杂度为$O(n)$

• 插入的时间复杂度为$O(1)$

• 每次查询只需要计算$sum(j)-sum(i-1)$，时间复杂度为$O(1)$
• 每次更新数字，就需要将包含这个数字的所有前缀和都更新出来，时间复杂度$O(n)$

• 例如实现查询[3,5]之间的和

• 更新节点的值
arr[3]=8


## 2、线段树的实现¶

left_node = 2 * node + 1
right_node = 2 * node + 2


import math

class SegmentTree:
def __init__(self, arr):
self.arr = arr
self.arr_length = len(arr)
self.tree = [0] * self.get_seg_tree_length(self.arr_length)
self.build_tree(0, 0, self.arr_length - 1)

def build_tree(self, node, start, end):
if start == end:
self.tree[node] = self.arr[start]
else:
mid = (start + end) // 2
left_node = 2 * node + 1
right_node = 2 * node + 2
self.build_tree(left_node, start, mid)
self.build_tree(right_node, mid + 1, end)
self.tree[node] = self.tree[left_node] + self.tree[right_node]

def query(self, left, right):
return self._query_segment(0, 0, self.arr_length - 1, left, right)

def _query_segment(self, node, start, end, left, right):
if start == left and right == end or start == end:
return self.tree[node]

if left > end or right < start:
return 0
mid = (start + end) // 2
left_node = 2 * node + 1
right_node = 2 * node + 2

if right <= mid:
return self._query_segment(left_node, start, mid, left, right)
elif left > mid:
return self._query_segment(right_node, mid + 1, end, left, right)
else:
sum_left = self._query_segment(left_node, start, mid, left, mid)
sum_right = self._query_segment(right_node, mid + 1, end, mid + 1, right)
return sum_left + sum_right

def update(self, index, val):
diff = val - self.arr[index]
self.arr[index] = val
if diff != 0:
self._update_index(0, 0, self.arr_length - 1, index, diff)

def _update_index(self, node, start, end, index, diff):
self.tree[node] += diff
if start != end:
mid = (start + end) // 2
left_node = 2 * node + 1
right_node = 2 * node + 2
if index <= mid:
self._update_index(left_node, start, mid, index, diff)
else:
self._update_index(right_node, mid + 1, end, index, diff)

@staticmethod
def get_seg_tree_length(arr_length):
return 2 ** (math.ceil(math.log(arr_length, 2)) + 1) - 1

a = [5, 4, 3, 2, 1, 5, 7, 7, 9]
seg = SegmentTree(a)
print(seg.tree)

# assert query
for i in range(9):
for j in range(i, 9):
assert sum(a[i:j + 1]) == seg.query(i, j)

seg.update(3, 10)
print(seg.tree)
# assert update query
for i in range(9):
for j in range(i, 9):
assert sum(a[i:j + 1]) == seg.query(i, j)

# 线段树数组
# 更新之前
[43, 15, 28, 12, 3, 12, 16, 9, 3, 2, 1, 5, 7, 7, 9, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
# 更新之后
[51, 23, 28, 12, 11, 12, 16, 9, 3, 10, 1, 5, 7, 7, 9, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]