zl程序教程

您现在的位置是:首页 >  其他

当前栏目

P3384——树链剖分&&模板

amp模板 剖分 树链
2023-09-27 14:27:45 时间

题目描述

链接

如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z

操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和

操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z

操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和

解决方法

用链式前向星的方式保存树,两次DFS将树剖分成若干重链和轻链,套用线段树进行更新和查询,对子树的操作可以转化成连续节点间的操作(因为DFS时子树节点的编号也是连续的),注意取模和开$long \ \ long$.

而且单独$add$标记时是不用下推的,只需查询时累加即可(不知道为什么那些题解都用下推的)

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 
  4 typedef long long ll;
  5 #define lc o <<1
  6 #define rc o <<1 | 1
  7 //const ll INF = 0x3f3f3f3f;
  8 const int maxn = 200000 + 10;
  9 struct Edge
 10 {
 11     int to, next;
 12 }edges[2*maxn];
 13 int head[maxn];
 14 int cur, f[maxn], deep[maxn], size[maxn], son[maxn], rk[maxn], id[maxn], top[maxn], cnt;
 15 int n, q, w[maxn], root, mod;
 16 
 17 inline void addedge(int u, int v)
 18 {
 19     ++cur;
 20     edges[cur].next = head[u];
 21     head[u] = cur;
 22     edges[cur].to = v;
 23 }
 24 
 25 struct SegTree{
 26     ll sum[maxn << 2], addv[maxn << 2];
 27     void build(int o, int l, int r)
 28     {
 29         if(l == r)
 30         {
 31             sum[o] = w[rk[l]] % mod;
 32         }
 33         else
 34         {
 35             int mid = (l + r) >> 1;
 36             build(lc, l, mid);
 37             build(rc, mid+1, r);
 38             sum[o] = (sum[lc] + sum[rc]) % mod;
 39         }
 40     }
 41 
 42     void maintain(int o, int l, int r)
 43     {
 44         if(l == r)  //如果是叶子结点
 45             sum[o] = w[rk[l]] % mod;
 46         else     //如果是非叶子结点
 47             sum[o] = (sum[lc] + sum[rc]) % mod;
 48 
 49         sum[o] = (sum[o] + addv[o] * (r-l+1)) % mod;
 50     }
 51     //区间修改,[cl,cr] += v;
 52     void update(int o, int l, int r, int cl, int cr, int v)  //
 53     {
 54         if(cl <= l && r <= cr)  addv[o] = (addv[o] + v) % mod;
 55         else
 56         {
 57             int m = l + (r-l) /2;
 58             if(cl <= m)  update(lc, l, m, cl, cr, v);
 59             if(cr > m)  update(rc, m+1, r, cl, cr, v);
 60         }
 61         maintain(o, l, r);
 62     }
 63 
 64     //区间查询,sum{ql,qr}
 65     ll query(int o, int l,int r, ll add, int ql, int qr)
 66     {
 67         if(ql <= l && r <= qr)
 68         {
 69             //prllf("sum[o]:%d  %d*(%d-%d+1)\n", sum[o], add, r, l);
 70             return (sum[o] + add * (r-l+1)) % mod;  //tx  l-r+1
 71         }
 72         else
 73         {
 74             int  m = l + (r - l) / 2;
 75             ll ans = 0;
 76             add = (add + addv[o]) % mod;
 77             if(ql <= m)  ans = (ans + query(lc, l, m, add, ql, qr)) % mod;
 78             if(qr > m)  ans = (ans + query(rc, m+1, r, add, ql, qr)) % mod;
 79             return ans;
 80         }
 81     }
 82 }st;
 83 
 84 void dfs1(int u, int fa, int depth)  //当前节点、父节点、层次深度
 85 {
 86     //prllf("u:%d fa:%d depth:%d\n", u, fa, depth);
 87     f[u] = fa;
 88     deep[u] = depth;
 89     size[u] = 1;   //这个点本身的size
 90     for(int i = head[u];i;i = edges[i].next)
 91     {
 92         int v = edges[i].to;
 93         if(v == fa)  continue;
 94         dfs1(v, u, depth+1);
 95         size[u] += size[v];   //子节点的size已被处理,用它来更新父节点的size
 96         if(size[v] > size[son[u]])  son[u] = v;    //选取size最大的作为重儿子
 97     }
 98 }
 99 
100 void dfs2(int u, int t)  //当前节点、重链顶端
101 {
102     //prllf("u:%d t:%d\n", u, t);
103     top[u] = t;
104     id[u] = ++cnt;   //标记dfs序
105     rk[cnt] = u;     //序号cnt对应节点u
106     if(!son[u])  return;   //没有儿子?
107     dfs2(son[u], t);  //我们选择优先进入重儿子来保证一条重链上各个节点dfs序连续
108 
109     for(int i = head[u];i;i = edges[i].next)
110     {
111         int v = edges[i].to;
112         if(v != son[u] && v != f[u])  dfs2(v, v);  //这个点位于轻链顶端,那么它的top必然为它本身
113     }
114 }
115 
116 
117 
118 /*修改和查询的原理是一致的,以查询操作为例,其实就是个LCA,不过这里要使用top数组加速,因为top可以直接跳到该重链的起始顶点*/
119 /*注意,每次循环只能跳一次,并且让结点深的那个跳到top的位置,避免两者一起跳而插肩而过*/
120 ll querysum(int x, int y)
121 {
122     int fx = top[x], fy = top[y];
123     ll ans = 0;
124     while(fx != fy)   //当两者不在同一条重链上
125     {
126         if(deep[fx] >= deep[fy])
127         {
128             //prllf("%d %d\n", id[fx], id[x]);
129             ans = (ans + st.query(1, 1, n, 0, id[fx], id[x])) % mod;   //线段树区间求和,计算这条重链的贡献
130             x = f[fx]; fx = top[x];
131         }
132         else
133         {
134             //prllf("%d %d\n", id[fy], id[y]);
135             ans = (ans + st.query(1, 1, n, 0, id[fy], id[y])) % mod;
136             y = f[fy]; fy = top[y];
137         }
138     }
139 
140     //循环结束,两点位于同一重链上,但两者不一定为同一点,所以还要加上这两点之间的贡献
141     if(id[x] <= id[y])
142     {
143         //prllf("%d %d\n", id[x], id[y]);
144         ans = (ans + st.query(1, 1, n, 0, id[x], id[y])) % mod;
145     }
146     else
147     {
148         //prllf("%d %d\n", id[y], id[x]);
149         ans = (ans + st.query(1, 1, n, 0, id[y], id[x])) % mod;
150     }
151     return ans;
152 }
153 
154 void update_add(int x, int y, int add)
155 {
156     int fx = top[x], fy = top[y];
157     while(fx != fy)   //当两者不在同一条重链上
158     {
159         if(deep[fx] >= deep[fy])
160         {
161             st.update(1, 1, n, id[fx], id[x], add);
162             x = f[fx]; fx = top[x];
163         }
164         else
165         {
166             st.update(1, 1, n, id[fy], id[y], add);
167             y = f[fy]; fy = top[y];
168         }
169     }
170     //循环结束,两点位于同一重链上,但两者不一定为同一点,所以还要加上这两点之间的贡献
171     if(id[x] <= id[y])  st.update(1, 1, n, id[x], id[y], add);
172     else  st.update(1, 1, n, id[y], id[x], add);
173 }
174 
175 
176 int main()
177 {
178     scanf("%d%d%d%d", &n, &q, &root, &mod);
179     for(int i = 1;i <= n;i++)
180     {
181         scanf("%d", &w[i]);
182         w[i] %= mod;
183     }
184     for(int i = 1;i < n;i++)
185     {
186         int u, v;
187         scanf("%d%d", &u, &v);
188         addedge(u, v);
189         addedge(v, u);
190     }
191     dfs1(root, -1, 1);
192     dfs2(root, root);
193 
194 //    for(ll i = 1;i <= n;i++)  prllf("%d  ", id[i]);
195 //    prllf("\n");
196 //    for(ll i = 1;i <= n;i++)  prllf("%d  ", rk[i]);
197 //    prllf("\n");
198 
199     st.build(1, 1, n);
200     //scanf("%d", &q);
201     while(q--)
202     {
203         int op;
204         scanf("%d", &op);
205         if(op == 1)
206         {
207             int u, v, add;
208             scanf("%d%d%d", &u, &v, &add);
209             update_add(u, v,  add);
210         }
211         else if(op == 2)
212         {
213             int u, v;
214             scanf("%d%d", &u, &v);
215             printf("%lld\n", querysum(u, v));
216         }
217         else if(op == 3)
218         {
219             int u, add;
220             scanf("%d%d", &u, &add);
221             st.update(1, 1, n, id[u], id[u]+size[u]-1, add);
222         }
223         else
224         {
225             int u;
226             scanf("%d", &u);
227             printf("%lld\n",st.query(1, 1, n, 0, id[u], id[u]+size[u]-1));
228         }
229         //st.prll_debug(1, 1, n);
230     }
231 }