zl程序教程

您现在的位置是:首页 >  后端

当前栏目

【luogu P6800】Chirp Z-Transform(多项式)(NTT)(bluestein 算法)

算法 Luogu transform 多项式 NTT
2023-09-27 14:28:25 时间

Chirp Z-Transform

题目链接:luogu P6800

题目大意

给你一个多项式和 c,m,要你求把 c^0,c^1,…c,^m-1 分别带入多项式得到的值。

思路

考虑把答案也看做是多项式:
a n s i = F ( c i ) = ∑ j = 0 n − 1 a j c i j ans_i=F(c^i)=\sum\limits_{j=0}^{n-1}a_jc^{ij} ansi=F(ci)=j=0n1ajcij

然后又一个东西是: i j = ( i + j 2 ) − ( i 2 ) − ( j 2 ) ij=\binom{i+j}{2}-\binom{i}{2}-\binom{j}{2} ij=(2i+j)(2i)(2j)
简单证明:
( i + j 2 ) − ( i 2 ) − ( j 2 ) = ( i + j ) ( i + j − 1 ) − i ( i − 1 ) − j ( j − 1 ) 2 \binom{i+j}{2}-\binom{i}{2}-\binom{j}{2}=\dfrac{(i+j)(i+j-1)-i(i-1)-j(j-1)}{2} (2i+j)(2i)(2j)=2(i+j)(i+j1)i(i1)j(j1)
= i 2 + i j − i + i j + j 2 − j − i 2 + i − j 2 + j 2 = 2 i j 2 = i j =\dfrac{i^2+ij-i+ij+j^2-j-i^2+i-j^2+j}{2}=\dfrac{2ij}{2}=ij =2i2+iji+ij+j2ji2+ij2+j=22ij=ij

然后带入:
a n s i = F ( c i ) = ∑ j = 0 n − 1 a j c i j ans_i=F(c^i)=\sum\limits_{j=0}^{n-1}a_jc^{ij} ansi=F(ci)=j=0n1ajcij
= ∑ j = 0 n − 1 a j c ( i + j 2 ) − ( i 2 ) − ( j 2 ) =\sum\limits_{j=0}^{n-1}a_jc^{\binom{i+j}{2}-\binom{i}{2}-\binom{j}{2}} =j=0n1ajc(2i+j)(2i)(2j)
= c − ( i 2 ) ∑ j = 0 n − 1 a j c ( i + j 2 ) c − ( j 2 ) =c^{-\binom{i}{2}}\sum\limits_{j=0}^{n-1}a_jc^{\binom{i+j}{2}}c^{-\binom{j}{2}} =c(2i)j=0n1ajc(2i+j)c(2j)

然后发现右边这个部分可以直接卷起来,可以用 NTT 搞。

然后接着是求 c c c 的要用的次方项,每次都 log ⁡ \log log 太慢了,我们可以用预处理光速乘或者两边前缀和搞得出它。
两边阶乘的原理是 x ( x − 1 ) 2 = ( x − 1 ) ( ( x − 1 ) + 1 ) 2 = 1 + 2 + . . . + n − 1 \dfrac{x(x-1)}{2}=\dfrac{(x-1)((x-1)+1)}{2}=1+2+...+n-1 2x(x1)=2(x1)((x1)+1)=1+2+...+n1,然后就可以两边前缀和搞了。

你可以 a j c − ( j 2 ) a_jc^{-\binom{j}{2}} ajc(2j) 弄一个多项式(系数变成 n − j n-j nj),然后 c ( i + j 2 ) c^{\binom{i+j}{2}} c(2i+j) 弄一个多项式(系数就是这个),然后加了之后就是我们要的第 i i i 项了。(外面那个 c − ( i 2 ) c^{-\binom{i}{2}} c(2i) 最后输出的时候乘上即可)

代码

#include<cstdio>
#include<iostream>
#include<algorithm>
#define ll long long
#define mo 998244353

using namespace std;

int n, m, nm, limit, l_size;
int an[3000001];
ll a[3000001], c, invc, G, Gv;
ll jc[3000001], inv[3000001];
ll g[3000001];

ll ksm(ll x, ll y) {
	ll re = 1;
	while (y) {
		if (y & 1) re = re * x % mo;
		x = x * x % mo;
		y >>= 1;
	}
	return re;
}

void NTT(ll *now, int op) {//NTT
	for (int i = 0; i < limit; i++)
		if (i < an[i]) swap(now[i], now[an[i]]);
	
	for (int mid = 1; mid < limit; mid <<= 1) {
		ll Wn = ksm(op == 1 ? G : Gv, (mo - 1) / (mid << 1));
		for (int R = (mid << 1), j = 0; j < limit; j += R) {
			ll w = 1;
			for (int k = 0; k < mid; k++, w = w * Wn % mo) {
				ll x = now[j + k], y = w * now[j + mid + k] % mo;
				now[j + k] = (x + y) % mo;
				now[j + mid + k] = (x - y + mo) % mo;
			}
		}
	}
}

int main() {
	scanf("%d %lld %d", &n, &c, &m); nm = max(n, m);
	
	invc = ksm(c, mo - 2);//两次前缀和求出 n*(n-1)/2 次方的阶乘以及它的逆元
	jc[0] = 1;
	for (int i = 1; i <= n + m; i++) jc[i] = jc[i - 1] * c % mo;
	for (int i = 1; i <= n + m; i++) jc[i] = jc[i - 1] * jc[i] % mo;
	inv[0] = 1;
	for (int i = 1; i <= nm; i++) inv[i] = inv[i - 1] * invc % mo;
	for (int i = 1; i <= nm; i++) inv[i] = inv[i - 1] * inv[i] % mo;
	
	for (int i = 0; i < n; i++) scanf("%lld", &a[n - i]), a[n - i] = a[n - i] * (i ? inv[i - 1] : 1) % mo;
	for (int i = 0; i <= n + m; i++) g[i] = i ? jc[i - 1] : 1;
	
	limit = 1;
	while (limit <= n + m) {
		limit <<= 1;
		l_size++;
	}
	for (int i = 0; i < limit; i++)
		an[i] = (an[i >> 1] >> 1) | ((i & 1) << (l_size - 1));
	
	G = 3;
	Gv = ksm(G, mo - 2);
	
	NTT(a, 1); NTT(g, 1);
	for (int i = 0; i < limit; i++)
		a[i] = a[i] * g[i] % mo;
	NTT(a, -1);
	
	ll liv = ksm(limit, mo - 2);
	for (int i = 0; i < limit; i++) a[i] = a[i] * liv % mo;
	
	for (int i = 0; i < m; i++) {
		printf("%lld ", a[i + n] * (i ? inv[i - 1] : 1) % mo);
	}
	
	return 0;
}