zl程序教程

您现在的位置是:首页 >  工具

当前栏目

Splay平衡树 学习笔记

笔记学习 平衡
2023-06-13 09:12:47 时间

0 前言

突然想学Link Cut Tree了(当然不是因为我们班里有一个叫LCT的人 但是LCT有一个非常重要的前置知识,那就是Splay,于是就有了此篇文章。

1 Splay原理

BST

要想理解splay的原理,就得先理解BST。 二叉查找树(Binary Search Tree,简称BST)是一棵二叉树,它的左子节点的值比父节点的值要小,右节点的值要比父节点的值大。它的高度决定了它的查找效率。 比如这个就是一棵二叉查找树:

但是如果这棵二叉树变得丑陋点,就成了这样:

于是最坏查询情况就变成了O(N)这就尴尬了。

Splay

那么怎么解决如上所示的问题呢? 于是就变成了各种树。 其中有一位大佬叫Tarjan(怎么又是他 发明了Splay 那么Splay是怎么解决这个问题的呢? Tarjan想出了旋转。

2 Splay详解

Rotate

如图,我们有一棵二叉树,X,Y,Z分别代表三个节点,A,B,C分别代表三个子树。

现在,我们要把这棵二叉树的X节点转到Y节点的位置。 因为XY的左儿子,所以X<YY必定是X的右儿子。 因为YZ的左儿子,所以Y<ZX<ZX转到Y,旋转后X必定是Z的左儿子。 因为X的子树及X本身构成了Y的左儿子,所以X的子树及X本身一定<YX<B<YB旋转后是Y的左儿子。 因为CY的右儿子,所以C>YC一定是Y的右儿子。 而X的左儿子A是最小的,所以不管他,旋转后A还是X的左儿子。 检查一遍:A<X<B<Y<C<Z4种情况。

图画完了,我们可以总结下规律了。

  1. X旋转后到Y的位置。
  2. Y旋转后到X原来在Y的那个儿子的相对的儿子(如果X原来是Y的左儿子,Y旋转后就是X的右儿子)。
  3. Y的另一个不是X的儿子不变,X的原来XY的方向的儿子不变(如果X原来是Y的左儿子,X的左儿子就不变,Y的右儿子不变)。
  4. X的原来XY的方向相对的儿子旋转后变成了原来XY方向上的Y的儿子(如果X原来是Y的左儿子,X的右儿子就变成了Y的左儿子)。
inline void rotate(int x){
    re int y=tr[x].fa,z=tr[y].fa,k=tr[y].ch[1]==x;//y为x的父亲,z为y的父亲,x是y的哪个儿子 0 是左儿子,1 是右儿子
    tr[z].ch[tr[z].ch[1]==y]=x;//1.
    tr[x].fa=z;//x的父亲变为z
    tr[y].ch[k]=tr[x].ch[k^1];//4.
    tr[tr[x].ch[k^1]].fa=y;//更新父节点
    tr[x].ch[k^1]=y;//2.
    tr[y].fa=x;//更新父节点
    upd(y);upd(x);//更新每个点的数据
}

Splay

接下来考虑下一个问题:怎样把一个节点旋转到根节点呢?(比如上文的X旋转到Z) 先把X转到Y,再把Y转到Z?显然这是不行的,可以自己动手画一画,在某些情况下某条链可能仍然存在,这种情况下,Splay极有可能会被卡。 图我就不画了(懒 总结在这:

  1. XY分别是YZ的同一个儿子(如XY的左儿子,YZ的左儿子),先旋转Y,再旋转X
  2. XY分别是YZ的不同儿子(如XY的左儿子,YZ的右儿子),对X旋转2次。
inline void splay(int x,int rt){
    while(tr[x].fa!=rt){//直到把x转成rt的儿子
        re int y=tr[x].fa,z=tr[y].fa;//y,z分别为x的父节点、祖节点
        if(z!=rt)//如果z不是根节点,分两类旋转
        (tr[z].ch[0]==y)^(tr[y].ch[0]==x)?rotate(x):rotate(y);//分类
        rotate(x);
    }
    if(rt==0) root=x;//如果rt=0,把根节点更新为x
}

剩下的操作

剩下的操作和普通的BST差不多,这里就不再介绍。

3 Splay Code

Luogu P3369 【模板】普通平衡树

#include<algorithm>
#include<bitset>
#include<complex>
#include<deque>
#include<exception>
#include<fstream>
#include<functional>
#include<iomanip>
#include<ios>
#include<iosfwd>
#include<iostream>
#include<istream>
#include<iterator>
#include<limits>
#include<list>
#include<locale>
#include<map>
#include<memory>
#include<new>
#include<numeric>
#include<ostream>
#include<queue>
#include<set>
#include<sstream>
#include<stack>
#include<stdexcept>
#include<streambuf>
#include<string>
#include<typeinfo>
#include<utility>
#include<valarray>
#include<vector>
#include<cctype>
#include<cerrno>
#include<cfloat>
#include<ciso646>
#include<climits>
#include<clocale>
#include<cmath>
#include<csetjmp>
#include<csignal>
#include<cstdarg>
#include<cstddef>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<ctime>
using namespace std;

#define re  
#define int long long 

class Quick_Input_Output{
    private:
        static const int S=1<<21;
        #define tc() (A==B&&(B=(A=Rd)+fread(Rd,1,S,stdin),A==B)?EOF:*A++)
        char Rd[S],*A,*B;
        #define pc putchar
    public:
        #undef gc
        #define gc getchar 
        inline int read(){
            int res=0,f=1;char ch=gc();
            while(ch<'0'ch>'9'){if(ch=='-') f=-1;ch=gc();}
            while(ch>='0'&&ch<='9') res=res*10+ch-'0',ch=gc();
            return res*f;
        }
        inline void write(int x){
            if(x<0) pc('-'),x=-x;
            if(x<10) pc(x+'0');
            else write(x/10),pc(x%10+'0');
        }
        #undef gc
        #undef pc
}I;
#define File freopen("tmp.in","r",stdin);freopen("tmp.out","w",stdout);

class Splay{
//  private:
    public:
        int root,tot;
        struct Tree{
            int fa,ch[2],val,cnt,size;
        }tr[100010];
        inline void upd(int x){
            tr[x].size=tr[tr[x].ch[0]].size+tr[tr[x].ch[1]].size+tr[x].cnt;
        }
        inline void rotate(int x){
            re int y=tr[x].fa,z=tr[y].fa,k=tr[y].ch[1]==x;
            tr[z].ch[tr[z].ch[1]==y]=x;
            tr[x].fa=z;
            tr[y].ch[k]=tr[x].ch[k^1];
            tr[tr[x].ch[k^1]].fa=y;
            tr[x].ch[k^1]=y;
            tr[y].fa=x;
            upd(y);upd(x);
        }
        inline void splay(int x,int rt){
            while(tr[x].fa!=rt){
                re int y=tr[x].fa,z=tr[y].fa;
                if(z!=rt) (tr[z].ch[0]==y)^(tr[y].ch[0]==x)?rotate(x):rotate(y);
                rotate(x);
            }
            if(rt==0) root=x;
        }
        inline void find(int x){
            re int u=root;
            if(!u) return ;
            while(tr[u].ch[x>tr[u].val]&&x!=tr[u].val) u=tr[u].ch[x>tr[u].val]; 
            splay(u,0);
        }
        inline void insert(int x){
            re int u=root,ff=0;
            while(u&&tr[u].val!=x){
                ff=u;
                u=tr[u].ch[x>tr[u].val];
            }
            if(u) tr[u].cnt++;
            else{
                u=++tot;
                if(ff) tr[ff].ch[x>tr[ff].val]=u;
                tr[u].ch[0]=tr[u].ch[1]=0;
                tr[u].fa=ff;
                tr[u].val=x;
                tr[u].cnt=1;
                tr[u].size=1;
            }
            splay(u,0);
        }
        inline int pre(int x){
            find(x);
            re int u=root;
            if(tr[u].val<x) return u;
            u=tr[u].ch[0];
            while(tr[u].ch[1]) u=tr[u].ch[1];
            return u;
        }
        inline int nxt(int x){
            find(x);
            re int u=root;
            if(tr[u].val>x) return u;
            u=tr[u].ch[1];
            while(tr[u].ch[0]) u=tr[u].ch[0];
            return u;
        }
        inline void del(int x){
            re int Pre=pre(x),Nxt=nxt(x);
            splay(Pre,0);splay(Nxt,Pre);
            re int Del=tr[Nxt].ch[0];
//          cout<<"deling "<<x<<" "<<Del<<' '<<tr[Del].cnt<<endl;
            if(tr[Del].cnt>1) tr[Del].cnt--,splay(Del,0);
            else tr[Nxt].ch[0]=0;
        }
        inline int rank(int x){
            re int u=root;
            if(tr[u].size<x) return 0;
            while(1){
                re int y=tr[u].ch[0];
                if(x>tr[y].size+tr[u].cnt) x-=tr[y].size+tr[u].cnt,u=tr[u].ch[1];
                else if(tr[y].size>=x) u=y;
                else return tr[u].val;
            }
        }
        inline int Rank(int x){
            find(x);
            return tr[tr[root].ch[0]].size;
        }
        inline void PrintSplay(){
            cout<<"Now Root = "<<root<<endl;
            for(int i=1;i<=tot;i++){
                cout<<"Node id:"<<i<<" Fa:"<<tr[i].fa<<" CHL:"<<tr[i].ch[0]<<" CHR:"<<tr[i].ch[1]<<" "<<tr[i].cnt<<" "<<tr[i].val<<" sz:"<<tr[i].size<<endl;
            }
        }
}T;
int n;
signed main(){
//  freopen("input.txt","r",stdin);
    T.root=0;T.tot=0;
    T.insert(-2147483647);
    T.insert(2147483647);
    n=I.read();
    for(int op,x,i=1;i<=n;i++){
        op=I.read();x=I.read();
        if(op==1){
            T.insert(x);
        }else if(op==2){
            T.del(x); 
        }else if(op==3){
            I.write(T.Rank(x));putchar('\n');
        }else if(op==4){
            I.write(T.rank(x+1));putchar('\n');
        }else if(op==5){
            I.write(T.tr[T.pre(x)].val);putchar('\n');
        }else if(op==6){
            I.write(T.tr[T.nxt(x)].val);putchar('\n');
        }
//      cout<<"Round "<<i<<"\n";
//      T.PrintSplay();
    }
}