跳转至

线段树

1、什么是线段树?

线段树是一种数据结构,主要针对于区间的查询与更新的快速操作

例如下面的问题:

假设有长度为10w的数组,定义两种操作:
- 查询区间[i,j]中所有元素的和
- 更新位置i上面的数字

方式1:

  • 如果每次查询都计算一遍,那么操作1的时间复杂度为O(n)

  • 插入的时间复杂度为O(1)

方式2:

保存数组前缀和

  • 每次查询只需要计算sum(j)-sum(i-1),时间复杂度为O(1)
  • 每次更新数字,就需要将包含这个数字的所有前缀和都更新出来,时间复杂度O(n)

由于上面两种方式的缺陷,因此使用一个折中的方式:线段树

image-20190928095504544

如上,假设现在有上面的一个数组,其中的元素如上所示

image-20190928095827298

采用上面的方式存储元素,每一个元素表示他所指向的下标范围中的数字之和

如何实现查询操作呢?

  • 例如实现查询[3,5]之间的和

image-20190928100749994

首先从根节点出发,查找[3,5]之间的和,发现分布在两个子树中,继续向下查找

直到需要查找的区间与当前节点代表的区间相同为止

  • 更新节点的值
arr[3]=8

image-20190928102438407

如上所示,我们只需要将区间包含3的更新即可

2、线段树的实现

不需要直接构建二叉树,而是利用数组的下标之间的关系保存

image-20190928103550478

例如,根节点的左右节点分别为:

left_node = 2 * node + 1
right_node = 2 * node + 2

直接利用这种关系保存即可

代码示例如下:

import math


class SegmentTree:
    def __init__(self, arr):
        self.arr = arr
        self.arr_length = len(arr)
        self.tree = [0] * self.get_seg_tree_length(self.arr_length)
        self.build_tree(0, 0, self.arr_length - 1)

    def build_tree(self, node, start, end):
        if start == end:
            self.tree[node] = self.arr[start]
        else:
            mid = (start + end) // 2
            left_node = 2 * node + 1
            right_node = 2 * node + 2
            self.build_tree(left_node, start, mid)
            self.build_tree(right_node, mid + 1, end)
            self.tree[node] = self.tree[left_node] + self.tree[right_node]

    def query(self, left, right):
        return self._query_segment(0, 0, self.arr_length - 1, left, right)

    def _query_segment(self, node, start, end, left, right):
        if start == left and right == end or start == end:
            return self.tree[node]

        if left > end or right < start:
            return 0
        mid = (start + end) // 2
        left_node = 2 * node + 1
        right_node = 2 * node + 2

        if right <= mid:
            return self._query_segment(left_node, start, mid, left, right)
        elif left > mid:
            return self._query_segment(right_node, mid + 1, end, left, right)
        else:
            sum_left = self._query_segment(left_node, start, mid, left, mid)
            sum_right = self._query_segment(right_node, mid + 1, end, mid + 1, right)
            return sum_left + sum_right

    def update(self, index, val):
        diff = val - self.arr[index]
        self.arr[index] = val
        if diff != 0:
            self._update_index(0, 0, self.arr_length - 1, index, diff)

    def _update_index(self, node, start, end, index, diff):
        self.tree[node] += diff
        if start != end:
            mid = (start + end) // 2
            left_node = 2 * node + 1
            right_node = 2 * node + 2
            if index <= mid:
                self._update_index(left_node, start, mid, index, diff)
            else:
                self._update_index(right_node, mid + 1, end, index, diff)

    @staticmethod
    def get_seg_tree_length(arr_length):
        return 2 ** (math.ceil(math.log(arr_length, 2)) + 1) - 1


a = [5, 4, 3, 2, 1, 5, 7, 7, 9]
seg = SegmentTree(a)
print(seg.tree)

# assert query
for i in range(9):
    for j in range(i, 9):
        assert sum(a[i:j + 1]) == seg.query(i, j)

seg.update(3, 10)
print(seg.tree)
# assert update query
for i in range(9):
    for j in range(i, 9):
        assert sum(a[i:j + 1]) == seg.query(i, j)
# 线段树数组
# 更新之前
[43, 15, 28, 12, 3, 12, 16, 9, 3, 2, 1, 5, 7, 7, 9, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
# 更新之后
[51, 23, 28, 12, 11, 12, 16, 9, 3, 10, 1, 5, 7, 7, 9, 5, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]