zl程序教程

您现在的位置是:首页 >  其他

当前栏目

线段树模板

模板 线段
2023-06-13 09:14:21 时间

线段树模板

线段树是算法竞赛中常用的用来维护 区间信息 的数据结构。

线段树可以在

的时间复杂度内实现单点修改、区间修改、区间查询(区间求和,求区间最大值,求区间最小值)等操作。

线段树 + Lazy(数组)

class SegmentTree:
    def __init__(self, nums) -> None:
        self.n = len(nums)
        self.nums = nums
        self.tree = [0] * (4 * self.n)
        self.lazy = [0] * (4 * self.n)
        self.build(1, self.n, 1)

    def build(self, start, end, idx):
        # 对 [start, end] 区间建立线段树,当前根的编号为 idx
        if start == end:
            self.tree[idx] = self.nums[start - 1]
            return
        mid = start + ((end - start) >> 1)
        # 递归对左右区间建树
        self.build(start, mid, idx << 1)
        self.build(mid + 1, end, idx << 1 | 1)
        # 合并左右区间的结果
        self.pushup(idx)

    def query(self, start, end, idx, left, right):
        # [s, t] 为当前节点包含的区间, 当前根的编号为 idx
        # 查询 [left, right] 区间的结果

        # 当前区间为询问区间的子集时直接返回当前区间的和
        if left <= start and right >= end:
            return self.tree[idx]
        mid, sum = start + ((end - start) >> 1), 0
        self.pushdown(idx, mid - start + 1, end - mid)
        # 如果询问区间在左区间内,则递归查询左区间
        if left <= mid:
            sum += self.query(start, mid, idx << 1, left, right)
        # 如果询问区间在右区间内,则递归查询右区间
        if right > mid:
            sum += self.query(mid + 1, end, idx << 1 | 1, left, right)
        return sum

    def update(self, start, end, idx, left, right, val):
        # [s, t] 为当前节点包含的区间, 当前根的编号为 idx
        # 更新 [left, right] 区间的结果, 区间加上值 val

        # 当前区间为修改区间的子集时直接修改当前节点的值, 然后打标记, 结束修改
        if left <= start and right >= end:
            self.tree[idx] += (end - start + 1) * val
            self.lazy[idx] += val
            return
        mid = start + ((end - start) >> 1)
        self.pushdown(idx, mid - start + 1, end - mid)
        # 如果修改区间在左区间内,则递归更新左区间
        if left <= mid:
            self.update(start, mid, idx << 1, left, right, val)
        # 如果修改区间在右区间内,则递归更新右区间
        if right > mid:
            self.update(mid + 1, end, idx << 1 | 1, left, right, val)
        # 合并左右区间的结果
        self.pushup(idx)

    def pushup(self, idx):
        # 从儿子节点更新当前节点
        self.tree[idx] = self.tree[idx << 1] + self.tree[idx << 1 | 1]

    def pushdown(self, idx, ln, rn):
        # 当前根的编号为 idx, ln, rn 分别表示左右子树的节点数量
        # 从父节点更新当前节点, 下放懒惰标记
        if self.lazy[idx] != 0:
            # 更新当前节点两个子节点的值
            self.tree[idx << 1] += self.lazy[idx] * ln
            self.tree[idx << 1 | 1] += self.lazy[idx] * rn
            # 将标记下传给子节点
            self.lazy[idx << 1] += self.lazy[idx]
            self.lazy[idx << 1 | 1] += self.lazy[idx]
            # 清空当前节点的标记
            self.lazy[idx] = 0

线段树 + Lazy + 动态开点(类)

class SegmentTree:
    class Node:
        def __init__(self):
            self.left = None
            self.right = None
            self.val = 0
            self.lazy = 0

    def __init__(self) -> None:
        self.root = self.Node()

    @staticmethod
    def query(start: int, end: int, node: Node, left: int, right: int) -> int:
        # [s, t] 为当前节点包含的区间, 当前根为 node
        # 查询 [left, right] 区间的结果

        # 当前区间为询问区间的子集时直接返回当前区间的和
        if left <= start and right >= end:
            return node.val
        mid, sum = start + ((end - start) >> 1), 0
        SegmentTree.pushdown(node, mid - start + 1, end - mid)
        if left <= mid:
            sum += SegmentTree.query(start, mid, node.left, left, right)
        if right > mid:
            sum += SegmentTree.query(mid + 1, end, node.right, left, right)
        return sum

    @staticmethod
    def update(start: int, end: int, node: Node, left: int, right: int, val: int) -> None:
        # [s, t] 为当前节点包含的区间, 当前根为 node
        # 更新 [left, right] 区间值为 val

        # 当前区间为修改区间的子集时直接修改当前节点的值, 然后打标记, 结束修改
        if left <= start and right >= end:
            node.val += val * (end - start + 1)
            node.lazy += val
            return
        mid = start + ((end - start) >> 1)
        SegmentTree.pushdown(node, mid - start + 1, end - mid)
        if left <= mid:
            SegmentTree.update(start, mid, node.left, left, right, val)
        if right > mid:
            SegmentTree.update(mid + 1, end, node.right, left, right, val)
        SegmentTree.pushup(node)

    @staticmethod
    def pushup(node: Node):
        node.val = node.left.val + node.right.val

    @staticmethod
    def pushdown(node: Node, ln: int, rn: int):
        if node.left is None:
            node.left = SegmentTree.Node()
        if node.right is None:
            node.right = SegmentTree.Node()
        if node.lazy:
            # 更新当前节点两个子节点的值
            node.left.val += node.lazy * ln
            node.right.val += node.lazy * rn
            # 将标记下传给子节点
            node.left.lazy += node.lazy
            node.right.lazy += node.lazy
            # 清空当前节点的标记
            node.lazy = 0

参考资料