【高级数据结构】线段树 | 求区间和
线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。
参考:https://www.bilibili.com/video/BV1cb411t7AM?from=search&seid=7296088207308787099&spm_id_from=333.337.0.0 视频,文中部分图片取自该视频截图 [侵权删除],这里感谢up主 正月点灯笼 的视频讲解。
问题描述:用于数组的区间求和,且数组内的元素可修改。
对于数组 arr = {1,2,3,4,5,6,6,7,…} 而言,要求得 i~j 下标数据的和,有以下几种方法。
-
我们需要遍历数组求和。时间复杂度O(n)。而修改数组中某个元素的值,时间复杂度为O(1)。
-
我们可以使用一个 sum_arr 前缀和数组来保存 0~i 位置的和,这样一来,求 i~j 的区间和等同于 arr[j] - arr[i-1]。
时间复杂度为 O(1)。 而修改数组中某个元素的值后,sum_arr的结果都需要修改,时间复杂度为O(n)。 -
这里可以使用前缀树。根节点保存的是 0~size 的所有元素之和,
左区间保存 0~mid 的元素和, 右区间保存 mid+1 ~ size 的和。其他同理。
求区间和还有一种方法,使用树状数组(又名“二叉索引树”(Binary Indexed Tree)或Fenwick树)。参考:https://blog.dotcpp.com/a/78973
选取部分内容:
查询方法:先考虑
[1, r]
的求和, 从右往左取块,将块代表的数值加起来即可
.
图中的例子:欲求区间[1,13]
的和
- 第一次取到13,长度为lowbit(13) = 1
- 第二次13取完了从12开始取,长度为4,一次性将[9, 12]取完
- 第三次[9, 13]取完了从8开始,长度为8,取走[1, 8],到此[1, 13]全部取走
.
数据结果可视化模拟:https://visualgo.net/zh/fenwicktree
线段树query操作:如果我们要计算 [2,5] 的区间和,左边找到[2]节点,右边找到[3,5]节点,因此[2,5]区间和等于 5+27。
时间复杂度最坏情况下找到最后一个叶子节点,所以是O(logn)
update操作:修改某个叶子结点的值后,顺着该节点向上更新即可,时间复杂度O(logn)
实现:使用完全二叉树结构保存。
#include<stdio.h>
#include<stdlib.h>
int max_node = 0; /* 记录线段树的真实最大长度 */
// 将数组元素构建一颗线段树
void build_tree(int arr[], int tree[], int node, int start, int end)
{
if (start == end)
{ // 叶子节点,存放arr中的值
tree[node] = arr[start];
// 记录,线段树最大长度
max_node = std::max(max_node, node);
}
else {
int mid = (start + end) / 2;
int left_node = 2 * node + 1;
int right_node = 2 * node + 2;
build_tree(arr, tree, left_node, start, mid);
build_tree(arr, tree, right_node, mid + 1, end);
tree[node] = tree[left_node] + tree[right_node]; // 非叶子节点,存放下边两孩子的和
}
}
// 更新操作
void update_tree(int arr[], int tree[], int node, int start, int end, int id, int val)
{// start==end -》 叶子结点, node 表示该节点在 tree中的下标位置
if (start == end) {
arr[id] = val;
tree[node] = val;
}
else {
int mid = (start + end) / 2;
int left_node = 2 * node + 1;
int right_node = 2 * node + 2;
// 在左边区间内,向左子树查找
if (id >= start && id <= mid) {
update_tree(arr, tree, left_node, start, mid, id, val);
}
else if(id <= end){
update_tree(arr, tree, right_node, mid + 1, end, id, val);
}
tree[node] = tree[left_node] + tree[right_node]; // 更新上层节点的值
}
}
// 查询区间[L,R]的和。start,end表示原arr数字的范围
int query_tree(int arr[], int tree[], int node, int start, int end, int L, int R)
{
// 所求区间在当前区间左边,或右边。则无需继续向下遍历
if (start > R || end < L)
return 0;
// 优化,当前访问区间在,[L,R]内部,则无需向下继续寻找了
else if (L <= start && end <= R) {
return tree[node];
}
//if (L == R) return arr[L];
else if (start == end) {
return tree[node];
}
int mid = (start + end) / 2;
int left_node = 2 * node + 1;
int right_node = 2 * node + 2;
int sum_left = query_tree(arr, tree, left_node, start, mid, L, R);
int sum_right = query_tree(arr, tree, right_node, mid + 1, end, L, R);
return sum_left + sum_right;
}
/* 将数据打印,以表格的形式。 注:该函数与当前项目无关,忽略即可 */
void Print_table_frame(int arr[], int i, int n, const char* table_name = "arr数组", int size = 3)
{
if (nullptr == arr || i < 0) return;
if (size < 2) size = 2;
char buf[21] = "━━━━━━━━━━";
buf[2 * (size+1)] = '\0';
char space[] = "";
// 打印数组元素
printf("%s元素为: \n┎━", table_name);//┏┻
for (int id = i; id <= n; id++) printf("%s┒", buf);
printf("\n┃%*s下标", size-2, space);
for (int id = i; id < n; id++) printf("┃ %*d", size, id);
printf("┃\n┃━");
for (int id = i; id < n; id++) printf("%s╂", buf);
printf("%s┃\n┃%*s元素", buf, size-2, space);
for (int id = i; id < n; id++) printf("┃ %*d", size, arr[id]);
printf("┃\n┖━");
for (int id = i; id < n; id++) printf("%s┸", buf);
printf("%s┚\n", buf);
}
int main(void)
{
int arr[] = { 1,2,3,4,5,6,7,8,9,10,9,-8 }; /* 数组 ,100, 124,125 */
constexpr int n = sizeof(arr) / sizeof(arr[0]);
int tree[n * 4] = { 0 }; /* 线段树需要开4倍区间不会越界 */
build_tree(arr, tree, 0, 0, n-1); /* 构建线段树 */
//arr[4] = 6;
//update_tree(arr, tree, 0, 0, n-1, 4, 6); /* 修改数据: arr[4] = 6 */
update_tree(arr, tree, 0, 0, n - 1, 0, 6);
Print_table_frame(arr, 0, n, "arr数组");
Print_table_frame(tree, 0, max_node+1, "tree", 3);
int i = 6, j = 6;
int res = query_tree(arr, tree, 0, 0, n-1, i, j); /* 求i~j的和*/
int sum = 0;
for (; i <= j; ++i) sum += arr[i];
printf("%d , %d \n", sum, res);
return 0;
}
为了便于使用,我们将上述代码封装成类的形式。
// 使用类封装,可以继承,或包含使用。
// 以表格的形式打印数据
void Print_table_frame(const vector<int>& arr, const char* table_name = "arr数组", int size = 3)
{
if (arr.empty()) return;
if (size < 2) size = 2;
char buf[21] = "━━━━━━━━━━";
buf[2 * (size + 1)] = '\0';
char space[] = "";
int i = 0;
int n = arr.size();
// 打印数组元素
printf("%s元素为: \n┎━", table_name);//┏┻
for (int id = i; id <= n; id++) printf("%s┒", buf);
printf("\n┃%*s下标", size - 2, space);
for (int id = i; id < n; id++) printf("┃ %*d", size, id);
printf("┃\n┃━");
for (int id = i; id < n; id++) printf("%s╂", buf);
printf("%s┃\n┃%*s元素", buf, size - 2, space);
for (int id = i; id < n; id++) printf("┃ %*d", size, arr[id]);
printf("┃\n┖━");
for (int id = i; id < n; id++) printf("%s┸", buf);
printf("%s┚\n", buf);
}
class SegmentTree {
private:
vector<int>& arr; /* 原数组 */
vector<int> tree; /* 线段树数组 */
int max_id; /* 线段树长度 ,初始长度为原数组长度 */
public:
SegmentTree(vector<int>& vec)
:arr(vec), /* 引用该数组 */
tree(vec.size()*4, 0),
max_id(vec.size() - 1)
{
build_tree(0, 0, max_id); /* 构建线段树 */
tree.resize(max_id + 1); /* tree真正的长度*/
}
int query(int L, int R)
{
if (L == R) return arr[L];
return query_tree(0, 0, arr.size() - 1, L, R);
}
void update(int index, int val)
{
if (index < 0 || index >= arr.size()) return;
update_tree(0, 0, arr.size() - 1, index, val);
}
const vector<int>& getTreeArr() const
{
return tree;
}
private:
void build_tree(int node, int start, int end)
{
if (start == end)
{ // 叶子节点,存放arr中的值
tree[node] = arr[start];
// 记录,线段树最大长度
max_id = std::max(max_id, node);
}
else {
int mid = (start + end) / 2;
int left_node = 2 * node + 1;
int right_node = 2 * node + 2;
build_tree(left_node, start, mid);
build_tree(right_node, mid + 1, end);
tree[node] = tree[left_node] + tree[right_node]; // 非叶子节点,存放下边两孩子的和
}
}
int query_tree(int node, int start, int end, int L, int R)
{
// 所求区间在当前区间左边,或右边。则无需继续向下遍历
if (start > R || end < L)
return 0;
// 优化,当前访问区间在,[L,R]内部,则无需向下继续寻找了
else if (L <= start && end <= R) {
return tree[node];
}
//if (L == R) return arr[L];
else if (start == end) {
return tree[node];
}
int mid = (start + end) / 2;
int left_node = 2 * node + 1;
int right_node = 2 * node + 2;
int sum_left = query_tree(left_node, start, mid, L, R);
int sum_right = query_tree(right_node, mid + 1, end, L, R);
return sum_left + sum_right;
}
void update_tree(int node, int start, int end, int id, int val)
{// start==end -》 叶子结点, node 表示该节点在 tree中的下标位置
if (start == end) {
arr[id] = val;
tree[node] = val;
}
else {
int mid = (start + end) / 2;
int left_node = 2 * node + 1;
int right_node = 2 * node + 2;
// 在左边区间内,向左子树查找
if (id >= start && id <= mid) {
update_tree(left_node, start, mid, id, val);
}
else if (id <= end) {
update_tree(right_node, mid + 1, end, id, val);
}
tree[node] = tree[left_node] + tree[right_node]; // 更新上层节点的值
}
}
};
int main(void)
{
vector<int> arr = { 1,2,3,4,5,6,7,8,9,10,9,-8 };
SegmentTree tree(arr);
Print_table_frame(arr, "arr数组");
Print_table_frame(tree.getTreeArr(), "tree数组");
int i = 6, j = 6;
int res = tree.query(i, j); /* 求i~j的和*/
int sum = 0;
for (; i <= j; ++i) sum += arr[i];
printf("%d , %d \n", sum, res);
tree.update(6, 100);
Print_table_frame(arr, "arr数组");
Print_table_frame(tree.getTreeArr(), "tree数组");
i = 6, j = 7;
res = tree.query(i, j); /* 求i~j的和*/
sum = 0;
for (; i <= j; ++i) sum += arr[i];
printf("%d , %d \n", sum, res);
return 0;
}
练习:307. 区域和检索 - 数组可修改
https://leetcode-cn.com/problems/range-sum-query-mutable/submissions/
class SegmentTree; /* 自实现的线段树,参考👆 */
/* 包含的方式使用:在 NumArray 类中包含 SegmentTree对象,然后使用SegmentTree对象的方法。 */
class NumArray {
SegmentTree tree;
public:
NumArray(vector<int>& nums)
:tree(nums)
{
}
void update(int index, int val) {
tree.update(index, val);
}
int sumRange(int left, int right) {
return tree.query(left, right);
}
};
/* 继承的方式使用:继承SegmentTree类,就可以使用SegmentTree的方法了。 */
class NumArray: protected SegmentTree {
/* 这里使用保护继承,因此,SegmentTree类中的方法在NumArray类中都是保护的。这样就可以避免NumArrayd的类对象在类外部使用SegmentTree 的方法。只暴露NumArray本身提供的方法。*/
public:
NumArray(vector<int>& nums)
:SegmentTree(nums)
{
}
void update(int index, int val) {
SegmentTree::update(index, val);
}
int sumRange(int left, int right) {
return SegmentTree::query(left, right);
}
};
/**
* Your NumArray object will be instantiated and called as such:
* NumArray* obj = new NumArray(nums);
* obj->update(index,val);
* int param_2 = obj->sumRange(left,right);
*/