zl程序教程

您现在的位置是:首页 >  后端

当前栏目

【高级数据结构】线段树 | 求区间和

数据结构 高级 线段 区间
2023-09-27 14:28:31 时间

线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。

参考:https://www.bilibili.com/video/BV1cb411t7AM?from=search&seid=7296088207308787099&spm_id_from=333.337.0.0 视频,文中部分图片取自该视频截图 [侵权删除],这里感谢up主 正月点灯笼 的视频讲解。


问题描述:用于数组的区间求和,且数组内的元素可修改

对于数组 arr = {1,2,3,4,5,6,6,7,…} 而言,要求得 i~j 下标数据的和,有以下几种方法。

  1. 我们需要遍历数组求和。时间复杂度O(n)。而修改数组中某个元素的值,时间复杂度为O(1)。

  2. 我们可以使用一个 sum_arr 前缀和数组来保存 0~i 位置的和,这样一来,求 i~j 的区间和等同于 arr[j] - arr[i-1]。
    时间复杂度为 O(1)。 而修改数组中某个元素的值后,sum_arr的结果都需要修改,时间复杂度为O(n)。

  3. 这里可以使用前缀树。根节点保存的是 0~size 的所有元素之和,
    左区间保存 0~mid 的元素和, 右区间保存 mid+1 ~ size 的和。其他同理。
    在这里插入图片描述

求区间和还有一种方法,使用树状数组(又名“二叉索引树”(Binary Indexed Tree)或Fenwick树)。参考:https://blog.dotcpp.com/a/78973
选取部分内容:
在这里插入图片描述

查询方法:先考虑[1, r]的求和, 从右往左取块,将块代表的数值加起来即可
.
图中的例子:欲求区间 [1,13] 的和

  • 第一次取到13,长度为lowbit(13) = 1
  • 第二次13取完了从12开始取,长度为4,一次性将[9, 12]取完
  • 第三次[9, 13]取完了从8开始,长度为8,取走[1, 8],到此[1, 13]全部取走
    .
    数据结果可视化模拟:https://visualgo.net/zh/fenwicktree
    在这里插入图片描述

线段树query操作:如果我们要计算 [2,5] 的区间和,左边找到[2]节点,右边找到[3,5]节点,因此[2,5]区间和等于 5+27。

时间复杂度最坏情况下找到最后一个叶子节点,所以是O(logn)
在这里插入图片描述
update操作:修改某个叶子结点的值后,顺着该节点向上更新即可,时间复杂度O(logn)
在这里插入图片描述


实现:使用完全二叉树结构保存。
在这里插入图片描述

#include<stdio.h>
#include<stdlib.h>

int max_node = 0;	/* 记录线段树的真实最大长度 */
// 将数组元素构建一颗线段树
void build_tree(int arr[], int tree[], int node, int start, int end)
{   
    if (start == end)
    {   // 叶子节点,存放arr中的值
        tree[node] = arr[start];
        // 记录,线段树最大长度
        max_node = std::max(max_node, node);
    }
    else {
        int mid = (start + end) / 2;
        int left_node = 2 * node + 1;
        int right_node = 2 * node + 2;

        build_tree(arr, tree, left_node, start, mid);
        build_tree(arr, tree, right_node, mid + 1, end);
        tree[node] = tree[left_node] + tree[right_node];	// 非叶子节点,存放下边两孩子的和
    }
}
// 更新操作
void update_tree(int arr[], int tree[], int node, int start, int end, int id, int val)
{// start==end -》 叶子结点, node 表示该节点在 tree中的下标位置
    if (start == end) {
        arr[id] = val;
        tree[node] = val;
    }
    else {
        int mid = (start + end) / 2;
        int left_node = 2 * node + 1;
        int right_node = 2 * node + 2;
        // 在左边区间内,向左子树查找
        if (id >= start && id <= mid) {
            update_tree(arr, tree, left_node, start, mid, id, val);
        } 
        else if(id <= end){
            update_tree(arr, tree, right_node, mid + 1, end, id, val);
        }
        tree[node] = tree[left_node] + tree[right_node];	// 更新上层节点的值
    }
}
// 查询区间[L,R]的和。start,end表示原arr数字的范围
int query_tree(int arr[], int tree[], int node, int start, int end, int L, int R)
{
    // 所求区间在当前区间左边,或右边。则无需继续向下遍历
    if (start > R || end < L)
        return 0;
    // 优化,当前访问区间在,[L,R]内部,则无需向下继续寻找了
    else if (L <= start && end <= R) {
        return tree[node];
    }
    //if (L == R) return arr[L];
    else if (start == end) {
        return tree[node];
    }

    int mid = (start + end) / 2;
    int left_node = 2 * node + 1;
    int right_node = 2 * node + 2;
    int sum_left = query_tree(arr, tree, left_node, start, mid, L, R);
    int sum_right = query_tree(arr, tree, right_node, mid + 1, end, L, R);

    return sum_left + sum_right;
}

/* 将数据打印,以表格的形式。 注:该函数与当前项目无关,忽略即可 */
void Print_table_frame(int arr[], int i, int n, const char* table_name = "arr数组", int size = 3)
{
    if (nullptr == arr || i < 0) return;
    if (size < 2) size = 2; 

    char buf[21] = "━━━━━━━━━━";
    buf[2 * (size+1)] = '\0';
    char space[] = "";

    // 打印数组元素
    printf("%s元素为: \n┎━", table_name);//┏┻
    for (int id = i; id <= n; id++)  printf("%s┒", buf);
    printf("\n┃%*s下标", size-2, space);
    for (int id = i; id < n; id++)  printf("┃ %*d", size, id);
    printf("┃\n┃━");
    for (int id = i; id < n; id++)  printf("%s╂", buf);
    printf("%s┃\n┃%*s元素", buf, size-2, space);
    for (int id = i; id < n; id++)  printf("┃ %*d", size, arr[id]);
    printf("┃\n┖━");
    for (int id = i; id < n; id++)  printf("%s┸", buf);
    printf("%s┚\n", buf);
}

int main(void)
{
    int arr[] = { 1,2,3,4,5,6,7,8,9,10,9,-8 };      /* 数组 ,100, 124,125 */
    constexpr int n = sizeof(arr) / sizeof(arr[0]);
    int tree[n * 4] = { 0 };        /* 线段树需要开4倍区间不会越界 */

    build_tree(arr, tree, 0, 0, n-1);            /* 构建线段树 */
    
    //arr[4] = 6;
    //update_tree(arr, tree, 0, 0, n-1, 4, 6);     /* 修改数据: arr[4] = 6   */
    update_tree(arr, tree, 0, 0, n - 1, 0, 6);

    Print_table_frame(arr, 0, n, "arr数组");
    Print_table_frame(tree, 0, max_node+1, "tree", 3);


    int i = 6, j = 6;
    int res = query_tree(arr, tree, 0, 0, n-1, i, j); /* 求i~j的和*/

    int sum = 0;
    for (; i <= j; ++i) sum += arr[i];
    printf("%d , %d \n", sum, res);

    return 0;
}

为了便于使用,我们将上述代码封装成类的形式。

// 使用类封装,可以继承,或包含使用。

// 以表格的形式打印数据
void Print_table_frame(const vector<int>& arr, const char* table_name = "arr数组", int size = 3)
{
    if (arr.empty()) return;
    if (size < 2) size = 2;

    char buf[21] = "━━━━━━━━━━";
    buf[2 * (size + 1)] = '\0';
    char space[] = "";
    int i = 0;
    int n = arr.size();
    // 打印数组元素
    printf("%s元素为: \n┎━", table_name);//┏┻
    for (int id = i; id <= n; id++)  printf("%s┒", buf);
    printf("\n┃%*s下标", size - 2, space);
    for (int id = i; id < n; id++)  printf("┃ %*d", size, id);
    printf("┃\n┃━");
    for (int id = i; id < n; id++)  printf("%s╂", buf);
    printf("%s┃\n┃%*s元素", buf, size - 2, space);
    for (int id = i; id < n; id++)  printf("┃ %*d", size, arr[id]);
    printf("┃\n┖━");
    for (int id = i; id < n; id++)  printf("%s┸", buf);
    printf("%s┚\n", buf);
}


class SegmentTree {
private:
    vector<int>& arr;   /* 原数组 */
    vector<int> tree;   /* 线段树数组 */
    int max_id;         /* 线段树长度 ,初始长度为原数组长度 */
public:
    SegmentTree(vector<int>& vec)
        :arr(vec),  /* 引用该数组 */
        tree(vec.size()*4, 0),
        max_id(vec.size() - 1)
    {
        build_tree(0, 0, max_id);   /* 构建线段树 */
        tree.resize(max_id + 1);    /* tree真正的长度*/
    }
    int query(int L, int R)
    {
        if (L == R) return arr[L];
        return query_tree(0, 0, arr.size() - 1, L, R);
    }
    void update(int index, int val)
    {
        if (index < 0 || index >= arr.size()) return;
        update_tree(0, 0, arr.size() - 1, index, val);
    }
    const vector<int>& getTreeArr() const
    {
        return tree;
    }
private:
    void build_tree(int node, int start, int end)
    {
        if (start == end)
        {   // 叶子节点,存放arr中的值
            tree[node] = arr[start];
            // 记录,线段树最大长度
            max_id = std::max(max_id, node);
        }
        else {
            int mid = (start + end) / 2;
            int left_node = 2 * node + 1;
            int right_node = 2 * node + 2;

            build_tree(left_node, start, mid);
            build_tree(right_node, mid + 1, end);
            tree[node] = tree[left_node] + tree[right_node];	// 非叶子节点,存放下边两孩子的和
        }
    }
    int query_tree(int node, int start, int end, int L, int R)
    {
        // 所求区间在当前区间左边,或右边。则无需继续向下遍历
        if (start > R || end < L)
            return 0;
        // 优化,当前访问区间在,[L,R]内部,则无需向下继续寻找了
        else if (L <= start && end <= R) {
            return tree[node];
        }
        //if (L == R) return arr[L];
        else if (start == end) {
            return tree[node];
        }

        int mid = (start + end) / 2;
        int left_node = 2 * node + 1;
        int right_node = 2 * node + 2;
        int sum_left = query_tree(left_node, start, mid, L, R);
        int sum_right = query_tree(right_node, mid + 1, end, L, R);

        return sum_left + sum_right;
    }
    void update_tree(int node, int start, int end, int id, int val)
    {// start==end -》 叶子结点, node 表示该节点在 tree中的下标位置
        if (start == end) {
            arr[id] = val;
            tree[node] = val;
        }
        else {
            int mid = (start + end) / 2;
            int left_node = 2 * node + 1;
            int right_node = 2 * node + 2;
            // 在左边区间内,向左子树查找
            if (id >= start && id <= mid) {
                update_tree(left_node, start, mid, id, val);
            }
            else if (id <= end) {
                update_tree(right_node, mid + 1, end, id, val);
            }
            tree[node] = tree[left_node] + tree[right_node];	// 更新上层节点的值
        }
    }
};

int main(void)
{
    vector<int> arr = { 1,2,3,4,5,6,7,8,9,10,9,-8 };
    SegmentTree tree(arr);

    Print_table_frame(arr, "arr数组");
    Print_table_frame(tree.getTreeArr(), "tree数组");

    int i = 6, j = 6; 
    int res = tree.query(i, j); /* 求i~j的和*/

    int sum = 0;
    for (; i <= j; ++i) sum += arr[i];
    printf("%d , %d \n", sum, res);


    tree.update(6, 100);
    Print_table_frame(arr, "arr数组");
    Print_table_frame(tree.getTreeArr(), "tree数组");

    i = 6, j = 7;
    res = tree.query(i, j); /* 求i~j的和*/

    sum = 0;
    for (; i <= j; ++i) sum += arr[i];
    printf("%d , %d \n", sum, res);

    return 0;
}

在这里插入图片描述


练习:307. 区域和检索 - 数组可修改
https://leetcode-cn.com/problems/range-sum-query-mutable/submissions/

class SegmentTree;		/* 自实现的线段树,参考👆 */

/* 包含的方式使用:在 NumArray 类中包含 SegmentTree对象,然后使用SegmentTree对象的方法。 */
class NumArray {
    SegmentTree tree;
public:
    NumArray(vector<int>& nums)
    :tree(nums)
    {
    }
    
    void update(int index, int val) {
        tree.update(index, val);
    }
    
    int sumRange(int left, int right) {
        return tree.query(left, right);
    }
};

/* 继承的方式使用:继承SegmentTree类,就可以使用SegmentTree的方法了。 */
class NumArray: protected SegmentTree {	
/* 这里使用保护继承,因此,SegmentTree类中的方法在NumArray类中都是保护的。这样就可以避免NumArrayd的类对象在类外部使用SegmentTree 的方法。只暴露NumArray本身提供的方法。*/
public:
    NumArray(vector<int>& nums)
    :SegmentTree(nums)
    {
    }
    
    void update(int index, int val) {
        SegmentTree::update(index, val);
    }
    
    int sumRange(int left, int right) {
        return SegmentTree::query(left, right);
    }
};

/**
 * Your NumArray object will be instantiated and called as such:
 * NumArray* obj = new NumArray(nums);
 * obj->update(index,val);
 * int param_2 = obj->sumRange(left,right);
 */