跳转至

树状数组

1、什么是树状数组

数组数组就是按照树的组织形式,存储数组中部分和的数据结构

树状数组一般适用于三类问题:

  1. 修改一个点求一个区间
  2. 修改一个区间求一个点
  3. 求逆序列对

树状数组可以解决的问题都可以用线段树解决,这两者的区别在哪里呢?树状数组的系数要少很多

image-20190928120050354

如上所示,如果每个节点对应的都有两个子节点,就是线段树

image-20190928120135926

树状数组则是更加精简,如上图所示,同一个节点会存在多个子节点

树状数组的关键点就在于如何保存节点之间的关系

1 1
2 10
3 11
4 100
5 101
6 110
7 111
8 1000

如上,有8个节点,

C[1] = A[1];
C[2] = A[1] + A[2];
C[3] = A[3];
C[4] = A[1] + A[2] + A[3] + A[4];
C[5] = A[5];
C[6] = A[5] + A[6];
C[7] = A[7];
C[8] = A[1] + A[2] + A[3] + A[4] + A[5] + A[6] + A[7] + A[8];

每一个节点保存当前节点的值,加上二进制中最右面的0个个数对应的元素值,

C[i] = A[i - 2k+1] + A[i - 2k+2] + ... + A[i];   //k为i的二进制中从最低位到高位连续零的长度
例如:
1, 没有0,只保存1
2,10,有1个0,保存当前值和这个0变成1的组合,也就是1
3, 没有0,只保存1
4,100,保存4和 00 的组合,也就是1,2,3

如果我们想要获取前面7个元素之和,只需要获取7,6,4对应的位置和即可

他们之间恰好有对应的关系

第一步:获取7中和值   
    111
第二步:去掉最低位的1,变成6,获取6中和值
    111 - 1 = 110
第三部,去掉低位的1,变成4,获取4中和值
    110 - 10 = 100

如何对元素进行更新呢?

假设需要更新第3个元素,只需要更新包含第三个元素的和值即可,也就是3,4,8

他们之间的关系是什么呢?

第一步:更新3
   11
第二步:3+ (3&-3) = 4, 也就是3加上最低的位1,因为上面的规则确定了相互之间的包含关系
   11 + 1 = 100
第三步:4 + (4&-4) = 8 
   100 + 100 = 1000

因此,线段树的父子关系的判断与树状数组不同,树状数组是通过二进制位来判断

2、树状数组实现

class BinaryIndexTree:
    def __init__(self, nums):
        """
        :type nums: List[int]
        """

        self.nums = nums
        self.buff = [0] * (len(nums) + 1)
        for i in range(1, len(nums) + 1):
            temp = nums[i - 1]
            while i < (len(nums) + 1):
                self.buff[i] += temp
                i += i & (-i)

    def update(self, i, val):
        """
        :type i: int
        :type val: int
        :rtype: void
        """
        index = i + 1
        diff = val - self.nums[i]
        while index < len(self.buff):
            self.buff[index] += diff
            index += index & (-index)

        self.nums[i] = val

    def sumRange(self, i, j):
        """
        :type i: int
        :type j: int
        :rtype: int
        """
        return self.getSum(j + 1) - self.getSum(i)

    def getSum(self, index):
        res = 0

        while index > 0:
            res += self.buff[index]
            index -= index & (-index)

        return res


a = [5, 4, 3, 2, 1, 5, 7, 7, 9]
seg = BinaryIndexTree(a)
print(seg.buff)

# assert query
for left in range(9):
    for right in range(left, 9):
        assert sum(a[left:right + 1]) == seg.sumRange(left, right)
[0, 5, 9, 3, 14, 1, 6, 7, 34, 9]