zl程序教程

您现在的位置是:首页 >  其他

当前栏目

【luogu P5205】【模板】多项式开根(牛顿迭代)(多项式)

模板迭代 Luogu 多项式 牛顿
2023-09-27 14:28:25 时间

【模板】多项式开根

题目链接:luogu P5205

题目大意

给你一个多项式 A(x),要你找到一个多项式 B(x) 使得 B^2(x)=A(x) (mod x^n)。
如果有多个解选零次项系数最小的,然后保证 A(0)=1。

思路

在知道如果弄之前,我们要知道一个叫做牛顿迭代的东西。
那我们就要从泰勒展开开始,于是我们就讲讲求导,积分这些东西。

求导与积分

求导和积分是两个互逆的操作,先不收多项式的吧。
求导就是求图像在某个位置的斜率,积分就是求图像的面积。
然后在多项式就是这样:
求导: F ′ ( x ) = ∑ i = 0 F [ i ] i x i − 1 F'(x)=\sum\limits_{i=0}F[i]ix^{i-1} F(x)=i=0F[i]ixi1
积分: F ′ ( x ) = C + ∑ i = 0 F [ i ] x i + 1 / i F'(x)=C+\sum\limits_{i=0}F[i]x^{i+1}/i F(x)=C+i=0F[i]xi+1/i

然后这里简单贴一下代码:

void dao(int *f, int m) {
	for (int i = 1; i < m; i++)
		f[i - 1] = mul(f[i], i);
	f[m - 1] = 0;
}

void ji(int *f, int m) {
	for (int i = m; i >= 1; i--)
		f[i] = mul(f[i - 1], inv[i]);
	f[0] = 0;
}

然后多项式的复合就直接是把 x x x 带入到 G ( x ) G(x) G(x),就可以得到:
F ( G ( x ) ) = ∑ i = 0 F [ i ] G i ( x ) F(G(x))=\sum\limits_{i=0}F[i]G^i(x) F(G(x))=i=0F[i]Gi(x)

多项式牛顿迭代

那在知道牛顿迭代之前我们要知道一个叫做泰勒展开的东西。

我们设 F ( n ) ( x ) F^{(n)}(x) F(n)(x) f ( x ) f(x) f(x) n n n 阶导数,那么就有泰勒展开的式子:
F ( x ) = ∑ i = 0 F ( i ) ( a ) i ! ( x − a ) i F(x)=\sum\limits_{i=0}\dfrac{F^{(i)}(a)}{i!}(x-a)^i F(x)=i=0i!F(i)(a)(xa)i
(不要问我怎么证,我也不会)
(大概就是你要推这个函数,你就尝试着不断拟合,即“得寸进尺”,让它所有阶导的结果都一样)
(所以你这个 i i i 弄到越后拟合的就越像)

然后就有一些东西:
麦克劳林级数:其实就是 a = 0 a=0 a=0 的泰勒展开:
F ( x ) = ∑ i = 0 F ( i ) ( a ) i ! x i F(x)=\sum\limits_{i=0}\dfrac{F^{(i)}(a)}{i!}x^i F(x)=i=0i!F(i)(a)xi
这个东西最多的应用就是用卷积的形式来定义一些东西,就比如 exp ⁡ x \exp x expx e x = ∑ i = 0 x i i ! e^x=\sum\limits_{i=0}\dfrac{x^i}{i!} ex=i=0i!xi。( exp ⁡ x \exp x expx 的导数是它自己)

然后就是多项式牛顿迭代了:
首先给一下式子:如果已知 G ( F ( x ) ) = 0 G(F(x))=0 G(F(x))=0,而且已经得知 G ( F ∗ ( x ) ) = 0 ( m o d x n 2 ) G(F_*(x))=0\pmod{x^{\frac{n}{2}}} G(F(x))=0(modx2n),那么就有 F ( x ) = F ∗ ( x ) − G ( F ∗ ( x ) ) G ′ ( F ∗ ( x ) ) F(x)=F_*(x)-\dfrac{G(F_*(x))}{G'(F_*(x))} F(x)=F(x)G(F(x))G(F(x))
一般来讲题目中的 G G G 函数都会比较简单,可以手动分析爆算。
然后不难看出可以倍增出来,常数项就特殊考虑一下。

然后这里给给证明:
首先 F ( x )     m o d   x n F(x)\ \bmod x^n F(x) modxn 显然也可以作为   m o d   x n 2 \bmod x^{\frac{n}{2}} modx2n 的一组解,所以 F ( x ) = F ∗ ( x ) ( m o d x n 2 ) F(x)=F_*(x)\pmod{x^{\frac{n}{2}}} F(x)=F(x)(modx2n)
然后我们把 G ( F ( x ) ) G(F(x)) G(F(x)) F ∗ ( x ) F_*(x) F(x) 的位置泰勒展开:
G ( F ( x ) ) = G ( F ∗ ( x ) ) + G ′ ( F ∗ ( x ) ) 1 ! ( F ( x ) − F ∗ ( x ) ) + G ′ ′ ( F ∗ ( x ) ) 2 ! ( F ( x ) − F ∗ ( x ) ) 2 + . . . G(F(x))=G(F_*(x))+\dfrac{G'(F_*(x))}{1!}(F(x)-F_*(x))+\dfrac{G''(F_*(x))}{2!}(F(x)-F_*(x))^2+... G(F(x))=G(F(x))+1!G(F(x))(F(x)F(x))+2!G(F(x))(F(x)F(x))2+...

然后根据上面那个相等的式子可以发现 F ∗ ( x ) F_*(x) F(x) F ( x ) F(x) F(x) 的后 n 2 \dfrac{n}{2} 2n 项是相同的,所以 ( F ( x ) − F ∗ ( x ) ) (F(x)-F_*(x)) (F(x)F(x)) 的最低项次也是 x n 2 x^\frac{n}{2} x2n
然后 ( F ( x ) − F ∗ ( x ) ) 2 (F(x)-F_*(x))^2 (F(x)F(x))2 的最低项次是 x n x^n xn,然后依次就增加下去,然后你会发现从二次方开始因为你是 ( m o d x n ) \pmod{x^n} (modxn),这些项都被模掉变成 0 0 0 了。
然后就只剩下前面的两项,就是 G ( F ( x ) ) = G ( F ∗ ( x ) ) + G ′ ( F ∗ ( x ) ) 1 ! ( F ( x ) − F ∗ ( x ) ) G(F(x))=G(F_*(x))+\dfrac{G'(F_*(x))}{1!}(F(x)-F_*(x)) G(F(x))=G(F(x))+1!G(F(x))(F(x)F(x))
然后你整理一下,就得到了:
F ( x ) = F ∗ ( x ) − G ( F ∗ ( x ) ) G ′ ( F ∗ ( x ) ) F(x)=F_*(x)-\dfrac{G(F_*(x))}{G'(F_*(x))} F(x)=F(x)G(F(x))G(F(x))

然后这里有一些性质以及特点:
因为 G ( F ∗ ( x ) ) G(F_*(x)) G(F(x)) 的最低次项至少是 x n 2 x^{\frac{n}{2}} x2n,所以分母 G ′ ( F ∗ ( x ) ) G'(F_*(x)) G(F(x)) 的精度只需要 x n 2 x^{\frac{n}{2}} x2n
而且记得搞的时候自由元是 F ( x ) F(x) F(x)

求逆元

我们从 B ( x ) ∗ A ( x ) = 1 ( m o d x n ) B(x)*A(x)=1\pmod{x^n} B(x)A(x)=1(modxn),然后 A ( x ) A(x) A(x) 已知。
那我们可以设 G ( x ) G(x) G(x),使得 G ( B ( x ) ) = A ( x ) B ( x ) − 1 ( m o d x n ) G(B(x))=A(x)B(x)-1\pmod{x^n} G(B(x))=A(x)B(x)1(modxn),那这里 A ( x ) A(x) A(x) 就是系数。
然后我们可以导一下: G ′ ( B ( x ) ) = A ( x ) ( m o d x n ) G'(B(x))=A(x)\pmod{x^n} G(B(x))=A(x)(modxn)
那我们同样设 B ∗ ( x ) B_*(x) B(x) ( m o d x n 2 ) \pmod{x^{\frac{n}{2}}} (modx2n) 的解。
然后就牛顿迭代:
B ( x ) = B ∗ ( x ) − G ( B ∗ ( x ) ) G ′ ( B ∗ ( x ) ) = B ∗ ( x ) − A ( x ) B ∗ ( x ) − 1 A ( x ) B(x)=B_*(x)-\dfrac{G(B_*(x))}{G'(B_*(x))}=B_*(x)-\dfrac{A(x)B_*(x)-1}{A(x)} B(x)=B(x)G(B(x))G(B(x))=B(x)A(x)A(x)B(x)1

然后分母 1 A ( x ) \dfrac{1}{A(x)} A(x)1 的精度只用达到 x n 2 x^{\frac{n}{2}} x2n,所以我们可以直接用 B ∗ ( x ) B_*(x) B(x) 代替它,因为是逆元嘛。
B ( x ) = B ∗ ( x ) − ( A ( x ) B ∗ ( x ) − 1 ) B ∗ ( x ) = 2 B ∗ ( x ) − B ∗ ( x ) 2 A ( x ) B(x)=B_*(x)-(A(x)B_*(x)-1)B_*(x)=2B_*(x)-B_*(x)^2A(x) B(x)=B(x)(A(x)B(x)1)B(x)=2B(x)B(x)2A(x)
这个跟我们用平方的方法推出来的是一样的。

求开方

类似的方法,设 B ( x ) 2 − A ( x ) = 0 B(x)^2-A(x)=0 B(x)2A(x)=0 A ( x ) A(x) A(x) 已知。
G ( x ) G(x) G(x) G ( B ( x ) ) = B ( x ) 2 − A ( x ) G(B(x))=B(x)^2-A(x) G(B(x))=B(x)2A(x),然后 G ′ ( B ( x ) ) = 2 B ( x ) G'(B(x))=2B(x) G(B(x))=2B(x)
上牛叠: B ( x ) = B ∗ ( x ) − G ( B ∗ ( x ) ) G ′ ( B ∗ ( x ) ) = B ∗ ( x ) − B ∗ ( x ) 2 − A ( x ) 2 B ∗ ( x ) = A ( x ) + B ∗ ( x ) 2 2 B ∗ ( x ) B(x)=B_*(x)-\dfrac{G(B_*(x))}{G'(B_*(x))}=B_*(x)-\dfrac{B_*(x)^2-A(x)}{2B_*(x)}=\dfrac{A(x)+B_*(x)^2}{2B_*(x)} B(x)=B(x)G(B(x))G(B(x))=B(x)2B(x)B(x)2A(x)=2B(x)A(x)+B(x)2

然后就可以用这个来倍增。
然后一开始的常数项就是要看给的 A ( x ) A(x) A(x) 的常数项,就是它的二次剩余。
然而这道题直接说了常数项式 1 1 1,那就可以直接搞了。

代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#define mo 998244353
#define clr(f, n) memset(f, 0, sizeof(int) * (n))
#define cpy(f, g, n) memcpy(f, g, sizeof(int) * (n)) 

using namespace std;

int n, a[800001], an[800001];
int G, Gv;

int read() {
	int re = 0, zf = 1;
	char c = getchar();
	while (c < '0' || c > '9') {
		if (c == '-') zf = -zf; c = getchar();
	}
	while (c >= '0' && c <= '9') {
		re = (re << 3) + (re << 1) + c - '0';
		c = getchar();
	}
	return re * zf;
}

void writee(int x) {
	if (x > 9) writee(x / 10);
	putchar(x % 10 + '0');
}

int write(int x) {
	if (x < 0) putchar('-'), x = -x;
	writee(x);
}

int mul(int x, int y) {
	return 1ll * x * y % mo;
}

int jia(int x, int y) {
	return (x + y) % mo;
}

int jian(int x, int y) {
	return (x - y + mo) % mo;
}

int ksm(int x, int y) {
	int re = 1;
	while (y) {
		if (y & 1) re = mul(re, x);
		x = mul(x, x); y >>= 1;
	}
	return re;
}

void NTT(int *f, int op, int limit) {
	for (int i = 0; i < limit; i++)
		if (an[i] > i) swap(f[i], f[an[i]]);
	for (int mid = 1; mid < limit; mid <<= 1) {
		int Wn = ksm(op == 1 ? G : Gv, (mo - 1) / (mid << 1));
		for (int R = (mid << 1), j = 0; j < limit; j += R) {
			int w = 1;
			for (int k = 0; k < mid; k++, w = mul(w, Wn)) {
				int x = f[j | k], y = mul(w, f[j | mid | k]);
				f[j | k] = jia(x, y); f[j | mid | k] = jian(x, y);
			}
		}
	}
	if (op == -1) {
		int invl = ksm(limit, mo - 2);
		for (int i = 0; i < limit; i++)
			f[i] = mul(f[i], invl);
	}
}

void px(int *f, int *g, int n) {
	for (int i = 0; i < n; i++)
		f[i] = mul(f[i], g[i]);
}

void get_an(int limit, int l_size) {
	for (int i = 0; i < limit; i++) an[i] = (an[i >> 1] >> 1) | ((i & 1) << (l_size - 1));
}

void cheng(int *f, int *g, int n, int m) {
	static int tmp[800001];
	cpy(tmp, g, m);
	int limit = 1, l_size = 0;
	for (; limit < n + m - 1; limit <<= 1, l_size++);
	get_an(limit, l_size);
	NTT(f, 1, limit); NTT(tmp, 1, limit);
	px(f, tmp, limit); NTT(f, -1, limit);
	clr(f + n, limit - n); clr(tmp, limit);
}

void invp(int *f, int m) {
	static int tmp[800001], r[800001], w[800001];
	int n = 1, l_size = 0;
	for (; n < m; n <<= 1);
	w[0] = ksm(f[0], mo - 2);
	for (int len = 2; len <= n; len <<= 1) {
		l_size++;
		get_an(len, l_size);
		
		for (int i = 0; i < (len >> 1); i++) r[i] = w[i];
		cpy(tmp, f, len);
		NTT(tmp, 1, len); NTT(r, 1, len); px(r, tmp, len); NTT(r, -1, len);
		clr(r, len >> 1);
		cpy(tmp, w, len);
		NTT(tmp, 1, len); NTT(r, 1, len); px(r, tmp, len); NTT(r, -1, len);
		
		for (int i = len >> 1; i < len; i++)
			w[i] = jian(jia(w[i], w[i]), r[i]);
	}
	cpy(f, w, m); clr(w, n); clr(r, n); clr(tmp, n);
}

void pow2p(int *f, int limit, int l_size) {
	get_an(limit, l_size);
	NTT(f, 1, limit); px(f, f, limit); NTT(f, -1, limit);
}

void sqrtp(int *f, int m) {
	static int b1[800001], b2[800001];
	int limit = 1;
	for (; limit < m; limit <<= 1);
	b1[0] = 1; int l_size = 0;
	for (int len = 2; len <= limit; len <<= 1) {
		l_size++;
		for (int i = 0; i < (len >> 1); i++) b2[i] = jia(b1[i], b1[i]);
		invp(b2, len);
		pow2p(b1, len, l_size);
		for (int i = 0; i < len; i++) b1[i] = jia(b1[i], f[i]);
		cheng(b1, b2, len, len);
	}
	cpy(f, b1, m); clr(b1, limit + limit); clr(b2, limit + limit);
}

int main() {
//	freopen("read.txt", "r", stdin); 
	
	G = 3; Gv = ksm(G, mo - 2);
	
	n = read();
	for (int i = 0; i < n; i++) a[i] = read();
	
	sqrtp(a, n);
	for (int i = 0; i < n; i++) write(a[i]), putchar(' ');
	
	return 0;
}