zl程序教程

您现在的位置是:首页 >  其他

当前栏目

P7728 旧神归来 题解

2023-04-18 15:52:45 时间

日常生活:写多项式——写多项式题解——颓——写多项式——写多项式题解——颓——……

最近真的降智。大水题切不动。

#查询 gtm1514 精神状态

题解

好像挺清新的。

首先我们知道这是个多项式题。然后题目这个 b 东西不太好刻画,整点好搞的。

发现设 (f_i) 为当前有多少个深度为 (i) 的叶子比较好。

于是设出生成函数 (F)。假如当前最浅的叶子深度为 (d),那么一次变换可以写成:

[F_i=(1+x^d)F_{i-1}-x^d ]

那考虑怎么表示一下连续的几次操作。设一次消掉 (d) 深度的一个叶子,后一次消掉 (e) 深度的一个叶子,那么

[egin{aligned} &(1+x^e)left((1+x^d)F-x^d ight)-x^e\ =&(1+x^e)(1+x^d)F-x^d-x^e-x^{d+e}\ =&(1+x^e)(1+x^d)(F-1)+1 end{aligned} ]

初始的 (F_0) 可以搜一遍得到。显然任意深度有限的叶子都可以被消掉。那么设深度为 (i) 的叶子需要 (a_i) 步消掉,有

[(F-1)prod_{i=1}(1+x^i)^{a_i}+1=0 ]

[prod_{i=1}(1+x^i)^{a_i}=frac 1{1-F} ]

这东西如何处理我们是熟知的。取个 (ln) 得到:

[sum_{i=1}a_iln(1+x^i)=-ln(1-F) ]

[sum_{i=1}a_isum_{j=1}frac{(-1)^{j-1}x^{ij}}j=-ln(1-F) ]

右边一只 (log) 爆算。左边一只 (log) 爆算(当然我比较懒左边根号处理的)。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
using namespace std;
const int mod=998244353;
int n,m,wl;
struct gra{
    int v,w,next;
}edge[200010];
int t,head[100010],ans[100010],f[300010],w[300010],g[300010],inv[300010];
void Add(int u,int v){
    edge[++t].v=v;edge[t].next=head[u];head[u]=t;
}
void dfs(int x,int fa,int d){
    bool jud=false;
    for(int i=head[x];i;i=edge[i].next){
        if(edge[i].v!=fa)dfs(edge[i].v,x,d+1),jud=true;
    }
    if(!jud)g[d]++;
}
#define add(x,y) (x+y>=mod?x+y-mod:x+y)
#define sub(x,y) (x<y?x-y+mod:x-y)
void get(int n){
	wl=1;
	while(wl<n)wl<<=1;
}
int qpow(int a,int b){
    int ans=1;
    while(b){
        if(b&1)ans=1ll*ans*a%mod;
        a=1ll*a*a%mod;
        b>>=1;
    }
    return ans;
}
void init(int n){
    int t=1;
    while((1<<t)<n)t++;
    t=min(t-1,21);
    w[0]=1;w[1<<t]=qpow(31,1<<21-t);inv[1]=1;
    for(int i=t;i;i--)w[1<<i-1]=1ll*w[1<<i]*w[1<<i]%mod;
    for(int i=1;i<(1<<t);i++)w[i]=1ll*w[i&(i-1)]*w[i&-i]%mod;
	for(int i=2;i<=n;i++)inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
}
void DIF(int a[],int n){
    for(int mid=n>>1;mid>=1;mid>>=1){
        for(int i=0,k=0;i<n;i+=mid<<1,k++){
            for(int j=0;j<mid;j++){
                int x=1ll*a[i+j+mid]*w[k]%mod;
                a[i+j+mid]=sub(a[i+j],x);
                a[i+j]=add(a[i+j],x);
            }
        }
    }
}
void DIT(int a[],int n){
    for(int mid=1;mid<n;mid<<=1){
        for(int i=0,k=0;i<n;i+=mid<<1,k++){
            for(int j=0;j<mid;j++){
                int x=a[i+j+mid];
                a[i+j+mid]=1ll*sub(a[i+j],x)*w[k]%mod;
                a[i+j]=add(a[i+j],x);
            }
        }
    }
    int inv=qpow(n,mod-2);
    for(int i=0;i<n;i++)a[i]=1ll*a[i]*inv%mod;
    reverse(a+1,a+n);
}
#define mul(f,g,n) for(int i=0;i<n;i++)f[i]=1ll*f[i]*g[i]%mod
void getinv(int n,int f[],int g[]){
    get(n);
    static int tmp[300010],ret[300010];
    for(int i=0;i<wl;i++)g[i]=0;
    g[0]=qpow(f[0],mod-2);
    for(int len=2;len<=wl;len<<=1){
        memcpy(tmp,f,4*len);memcpy(ret,g,2*len);
        DIF(tmp,len);DIF(ret,len);mul(tmp,ret,len);
        DIT(tmp,len);
        memset(tmp,0,2*len);tmp[0]=mod-1;
        DIF(tmp,len);mul(ret,tmp,len);
        DIT(ret,len);
        for(int i=len>>1;i<len;i++)g[i]=sub(0,ret[i]);
    }
	for(int i=n;i<wl;i++)g[i]=0;
	for(int i=0;i<wl;i++)tmp[i]=ret[i]=0;
}
void dao(int f[],int n){
	for(int i=1;i<n;i++)f[i-1]=1ll*f[i]*i%mod;
	f[n-1]=0;
}
void jifen(int f[],int n){
	for(int i=n;i>=1;i--)f[i]=1ll*f[i-1]*inv[i]%mod;
	f[0]=0;
}
void getln(int n,int f[],int g[]){
    getinv(n,f,g);get(n<<1);
	for(int i=n;i<wl;i++)f[i]=g[i]=0;
    dao(f,wl);
    DIF(f,wl);DIF(g,wl);mul(g,f,wl);
    DIT(g,wl);
    jifen(g,wl);
	for(int i=n;i<wl;i++)g[i]=0;
}
int main(){
    scanf("%d%d",&n,&m);init(m+1<<1);
    for(int i=1;i<n;i++){
        int u,v;scanf("%d%d",&u,&v);
        Add(u,v);Add(v,u);
    }
    dfs(1,0,0);
    for(int i=0;i<=n;i++)g[i]=sub(0,g[i]);
    g[0]=add(g[0],1);
    getln(m+1,g,f);
    for(int i=1;i<=m;i++)f[i]=sub(0,f[i]);
    for(int k=1;k<=m;k++){
        ans[k]=f[k];
        for(int i=1;i*i<=k&&i<k;i++){
            if(k%i==0){
                int ret=1ll*inv[k/i]*ans[i]%mod;
                if((k/i)&1)ans[k]=sub(ans[k],ret);
                else ans[k]=add(ans[k],ret);
                if(i*i!=k&&i!=1){
                    ret=1ll*inv[i]*ans[k/i]%mod;
                    if(i&1)ans[k]=sub(ans[k],ret);
                    else ans[k]=add(ans[k],ret);
                }
            }
        }
    }
    for(int i=1;i<=m;i++){
        ans[i]=add(ans[i],ans[i-1]);
        printf("%d
",ans[i]);
    }
    return 0;
}