zl程序教程

您现在的位置是:首页 >  后端

当前栏目

树链剖分 算法学习

2023-06-13 09:12:49 时间

树链剖分 算法学习

树你应该懂的吧o( ̄︶ ̄)o 学习树链剖分之前需要先学习:dfs、线段树(当然大佬们用树状数组代替线段树也可以O(∩_∩)O),据说一名普及+的oier应该都会呀

先来了解树链剖分的用处

Luogu题目传送门 已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

  • 操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
  • 操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
  • 操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z
  • 操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和

如果直接暴力的话,肯定会TLE(废话)。所以这时候,树链剖分闪亮登场。

什么是树链剖分

一种算法(废话),它通过分轻重边把树分割成很多链,然后利用某种数据结构维护这些链(比如上文提到的线段树、树状数组等)但前提是这种数据结构支持动态修改(你别给我整个RMQ)。本质上是一种暴力算法。 PS:树剖的复杂度约为O(nlog^2n)

树链剖分的基本概念

名称

概念

重儿子

父亲节点的所有儿子中子节点数目最多(sz最大)的节点

轻儿子

父亲节点除了重儿子以外的儿子

重边

父亲节点和重儿子连成的边

轻边

父亲节点和轻儿子连成的边

重链

由多条重边连成的路径

轻链

由多条轻边连成的路径

没看懂?没关系,结合下面这张图:(红色的边代表重边,黑色的边代表轻边,红色的节点代表重儿子,黑色的节点代表轻儿子) PS:这里默认树根也是重儿子。

上图的重链有:1-4,3-6。

变量声明

ll fir[MAXN],nxt[MAXN*2],son[MAXN*2],w[MAXN*2],tot;
struct Node{
    ll sum,tag,l,r,ls,rs;
}a[2*MAXN];
ll root,n,m,r,mod,v[MAXN],cnt,fa[MAXN],dep[MAXN],sz[MAXN],c[MAXN],rk[MAXN],top[MAXN],id[MAXN];

名称

作用

fir_x

关于x的最后一条边编号

nxt_x

关于x的上一条边编号

son_x

x条边的连向

w_x

其实没啥用,打着习惯了

a_x.ls

编号为x的节点的左儿子

a_x.rs

编号为x的节点的右儿子

fa_x

编号为x的节点的父亲

c_x

编号为x的节点的重儿子

rk_x

当前dfs标号在树中所对应的节点的编号

top_x

编号为x的节点所在链的顶端节点编号

id_x

编号为x的节点dfs后的新编号

dep_x

编号为x的节点的深度

sz_x

以编号为x的节点为根的子树的节点个数

树链剖分的实现

第一次$dfs$求出每个节点的重儿子、父亲、深度、子树大小。

PS:如果一个点的多个儿子所在子树大小相等且最大,那随便找一个当做它的重儿子就好了,叶节点没有重儿子,非叶节点有且只有一个重儿子。

inline void dfs1(ll x,ll f,ll deep){
    fa[x]=f;//该节点的父亲
    dep[x]=deep;//该节点深度
    sz[x]=1;//该节点子树先设置为1(本身)
    for(ll i=fir[x];i;i=nxt[i]){//寻找与该节点相连的边
        ll to=son[i];//该边的另一个节点
        if(to==f) continue ;//如果另一个节点刚好是父亲,那么continue 
        dfs1(to,x,deep+1);//否则dfs该节点,并且父亲为本节点,深度+1
        sz[x]+=sz[to];//子树大小增加
        if(sz[to]>sz[c[x]]) c[x]=to;//重儿子更新(找子树最大的)
    }
}
//主函数调用
dfs1(root,0,1);

操作完以后应该是下图:

第二次$dfs$求出每个节点的链顶端节点、新编号、$dfs$编号对应的节点编号。

inline void dfs2(ll x,ll ttop){
    top[x]=ttop;//链顶端编号
    id[x]=++cnt;//新编号(dfs序)
    rk[cnt]=x;//新编号对应节点编号
    if(c[x]!=0) dfs2(c[x],ttop);//如果不是叶子节点,优先dfs重儿子,因为节点与重儿子处在同一重链,所以重儿子的重链顶端还是ttop
    for(ll i=fir[x];i;i=nxt[i]){
        ll to=son[i];
        if(to!=c[x]&&to!=fa[x]) dfs2(to,to);//如果既不是父亲也不是重儿子,那么就是该节点的轻儿子,那么dfs,且该节点的重链顶端为它本身
    }
}
//主函数调用
dfs2(root,root);

操作完以后应该是下图:

线段树等数据结构的维护

接下来就是线段树、树状数组等数据结构的维护了,具体使用哪种数据结构因题目而异,这里提供模板题(上文介绍的题目)所使用的线段树(区间修改、区间询问)。

inline void pushup(ll x){
    a[x].sum=(a[a[x].ls].sum+a[a[x].rs].sum)%mod;//更新求和
}
inline void build(ll l,ll r,ll x){
    if(l==r){
        a[x].sum=v[rk[l]];//符合所在区间,更新
        a[x].l=a[x].r=l;//l、r更新
        return ;
    }
    ll mid=l+r>>1;//线段树性质
    a[x].ls=cnt++;a[x].rs=cnt++;//左右儿子节点编号
    build(l,mid,a[x].ls);build(mid+1,r,a[x].rs);//分而治之
    a[x].l=a[a[x].ls].l,a[x].r=a[a[x].rs].r;//区间更新
    pushup(x);//sum更新
}
inline ll len(ll x){
    return a[x].r-a[x].l+1;//该区间的节点数量
}
inline void pushdown(ll x){
    if(a[x].tag!=0){//如果有lazy tag
        a[a[x].ls].tag+=a[x].tag;a[a[x].rs].tag+=a[x].tag;//向左右儿子传递
        a[a[x].ls].tag%=mod;a[a[x].rs].tag%=mod;
        a[a[x].ls].sum+=a[x].tag*len(a[x].ls);a[a[x].rs].sum+=a[x].tag*len(a[x].rs);//左右儿子更新
        a[a[x].ls].sum%=mod;a[a[x].rs].sum%=mod;
        a[x].tag=0;//lazy tag取消
    }
}
inline void update(ll l,ll r,ll c,ll x){
    if(a[x].l>=l&&a[x].r<=r){
        a[x].tag+=c;a[x].tag%=mod;//修改lazy tag
        a[x].sum+=len(x)*c;a[x].sum%=mod;//修改sum
        return ;
    }
    pushdown(x);//标记下传
    ll mid=a[x].l+a[x].r>>1;
    if(mid>=l) update(l,r,c,a[x].ls);//分而治之
    if(mid<r) update(l,r,c,a[x].rs);
    pushup(x);//更新sum
}
inline ll query(ll l,ll r,ll x){
    if(a[x].l>=l&&a[x].r<=r) return a[x].sum;//如果符合在本区间内,那么return
    pushdown(x);//标记下传
    ll mid=a[x].l+a[x].r>>1,ss=0;
    if(mid>=l) ss+=query(l,r,a[x].ls);ss%=mod;//分而治之
    if(mid<r) ss+=query(l,r,a[x].rs);ss%=mod;
    return ss;//返回
}
//主函数调用(根据上文题目)
cnt=0;build(1,n,root=cnt++);
update(id[x],id[x]+sz[x]-1,y,root);
query(id[x],id[x]+sz[x]-1,root);

根据题目需要添加操作

就比如上文的题目中还要求的操作:

  • 操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
  • 操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和

与操作3、操作4不同,这里要求的是一条路径上的节点,而没有告诉我们节点的编号,所以,我们这时要求出节点编号:

inline ll Query(ll x,ll y){
    ll res=0;
    while(top[x]!=top[y]){//若两点不再同一条链上
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        res+=query(id[top[x]],id[x],root);//ans更新
        res%=mod;
        x=fa[top[x]];//让x向上爬(与倍增思想类似,但有时复杂度更低)
    }
    if(id[x]>id[y]) swap(x,y);
    res+=query(id[x],id[y],root);//在同一条链,跳到同一点,ans更新
    res%=mod;
    return res;
}
inline void Update(ll x,ll y,ll c){
    while(top[x]!=top[y]){//两点不在同一条链
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        update(id[top[x]],id[x],c,root);//更新
        x=fa[top[x]];//让x向上爬
    }
    if(id[x]>id[y]) swap(x,y);
    update(id[x],id[y],c,root);//在同一链,跳到同一点,更新
}

当然,还有一个操作是非常常用的,那就是求lca(最近公共祖先)。

inline ll lca(ll x,ll y){
    while(top[x]!=top[y]){//两点不在同一条链上肯定没有公共祖先
        if(dep[top[x]]>=dep[top[y]])x=fa[top[x]];//让深度低的点向上爬,x向上爬
        else y=fa[top[y]];//y向上爬
    }
    return dep[x]<dep[y]?x:y;//取深度低的点
}

模板题代码

对对对,就是上文提到的题目。

#include<bits/stdc++.h>
#define MAXN 200010
#define ll long long
using namespace std;
ll fir[MAXN],nxt[MAXN*2],son[MAXN*2],w[MAXN*2],tot;
struct Node{
    ll sum,tag,l,r,ls,rs;
}a[2*MAXN];
ll root,n,m,r,mod,v[MAXN],cnt,fa[MAXN],dep[MAXN],sz[MAXN],c[MAXN],rk[MAXN],top[MAXN],id[MAXN];
inline void dfs1(ll x,ll f,ll deep){
    fa[x]=f;
    dep[x]=deep;
    sz[x]=1;
    for(ll i=fir[x];i;i=nxt[i]){
        ll to=son[i];
        if(to==f) continue ;
        dfs1(to,x,deep+1);
        sz[x]+=sz[to];
        if(sz[to]>sz[c[x]]) c[x]=to;
    }
}
inline void dfs2(ll x,ll ttop){
    top[x]=ttop;
    id[x]=++cnt;
    rk[cnt]=x;
    if(c[x]!=0) dfs2(c[x],ttop);
    for(ll i=fir[x];i;i=nxt[i]){
        ll to=son[i];
        if(to!=c[x]&&to!=fa[x]) dfs2(to,to);
    }
}
inline void pushup(ll x){
    a[x].sum=(a[a[x].ls].sum+a[a[x].rs].sum)%mod;
}
inline void build(ll l,ll r,ll x){
    if(l==r){
        a[x].sum=v[rk[l]];
        a[x].l=a[x].r=l;
        return ;
    }
    ll mid=l+r>>1;
    a[x].ls=cnt++;a[x].rs=cnt++;
    build(l,mid,a[x].ls);build(mid+1,r,a[x].rs);
    a[x].l=a[a[x].ls].l,a[x].r=a[a[x].rs].r;
    pushup(x);
}
inline ll len(ll x){
    return a[x].r-a[x].l+1;
}
inline void pushdown(ll x){
    if(a[x].tag!=0){
        a[a[x].ls].tag+=a[x].tag;a[a[x].rs].tag+=a[x].tag;
        a[a[x].ls].tag%=mod;a[a[x].rs].tag%=mod;
        a[a[x].ls].sum+=a[x].tag*len(a[x].ls);a[a[x].rs].sum+=a[x].tag*len(a[x].rs);
        a[a[x].ls].sum%=mod;a[a[x].rs].sum%=mod;
        a[x].tag=0;
    }
}
inline void update(ll l,ll r,ll c,ll x){
    if(a[x].l>=l&&a[x].r<=r){
        a[x].tag+=c;a[x].tag%=mod;
        a[x].sum+=len(x)*c;a[x].sum%=mod;
        return ;
    }
    pushdown(x);
    ll mid=a[x].l+a[x].r>>1;
    if(mid>=l) update(l,r,c,a[x].ls);
    if(mid<r) update(l,r,c,a[x].rs);
    pushup(x);
}
inline ll lca(ll x,ll y){
    while(top[x]!=top[y]){
        if(dep[top[x]]>=dep[top[y]])x=fa[top[x]];
        else y=fa[top[y]];
    }
    return dep[x]<dep[y]?x:y;
}
inline ll query(ll l,ll r,ll x){
    if(a[x].l>=l&&a[x].r<=r) return a[x].sum;
    pushdown(x);
    ll mid=a[x].l+a[x].r>>1,ss=0;
    if(mid>=l) ss+=query(l,r,a[x].ls);ss%=mod;
    if(mid<r) ss+=query(l,r,a[x].rs);ss%=mod;
    return ss;
}
inline ll Query(ll x,ll y){
    ll res=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        res+=query(id[top[x]],id[x],root);
        res%=mod;
        x=fa[top[x]];
    }
    if(id[x]>id[y]) swap(x,y);
    res+=query(id[x],id[y],root);
    res%=mod;
    return res;
}
inline void Update(ll x,ll y,ll c){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        update(id[top[x]],id[x],c,root);
        x=fa[top[x]];
    }
    if(id[x]>id[y]) swap(x,y);
    update(id[x],id[y],c,root);
}
inline ll read(){
    char ch=getchar();ll res=0,f=1;
    while(ch<'0'ch>'9'){if(ch=='-') f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9') res=res*10+ch-'0',ch=getchar();
    return res*f;
}
inline void write(ll x){
    if(x<10) putchar(x+'0');
    else{
        write(x/10);
        putchar(x%10+'0');
    }
}
inline void add(ll x,ll y){
    ++tot;
    son[tot]=y;
    nxt[tot]=fir[x];
    fir[x]=tot;
}
int main(){
    n=read();m=read();r=read();mod=read();
    for(ll i=1;i<=n;i++) v[i]=read();
    for(ll x,y,i=1;i<n;i++){
        x=read(),y=read();
        add(x,y);add(y,x);
    }
    cnt=0;dfs1(r,0,1);
    dfs2(r,r);
    cnt=0;build(1,n,root=cnt++);
    for(ll op,x,y,k,i=1;i<=m;i++){
        op=read();
        if(op==1){
            x=read();y=read();k=read();
            Update(x,y,k);
        }else if(op==2){
            x=read();y=read();
            write(Query(x,y));putchar('\n');
        }else if(op==3){
            x=read();y=read();
            update(id[x],id[x]+sz[x]-1,y,root);
        }else if(op==4){
            x=read();
            write(query(id[x],id[x]+sz[x]-1,root));putchar('\n');
        }
    }
    return 0;
}

完美撒花✿✿ヽ(°▽°)ノ✿