MnZn求助
查看原帖
MnZn求助
982518
sjwhsss楼主2024/12/16 21:40

RT，已知 Mul 函数有问题（NTT 和 Div 没问题，可以通过逆元模板题），但不知道问题出在哪里。把 FFT 模板题的代码复制过来也是错的，目测传参没有问题。

#include <bits/stdc++.h>
#define int long long
using namespace std;
// mod = 998244353 = 119 * 2^23 + 1 is an NTT-friendly prime with primitive root g = 3.
// gi must be g^{-1} modulo mod; it was previously mistyped as 332748181
// (3 * 332748181 % mod == 190), which broke every inverse NTT.
const int maxn = 5e5+5 , mod = 998244353 , g = 3 , gi = 332748118;
// Compile-time guard so a typo in the inverse root cannot slip through again.
static_assert(g * gi % mod == 1 , "gi must be the inverse of g modulo mod");
int n , a[maxn] , ap[maxn] , r[maxn] , c[maxn] , a_[maxn] , _a[maxn] , lna[maxn];
// Binary exponentiation: returns a^b reduced modulo the global `mod`.
// The exponent is consumed bit by bit from the least significant end; the
// loop runs while b is nonzero.
int qpow(int a , int b)
{
	int result = 1;
	for (int base = a; b != 0; b >>= 1)
	{
		if (b & 1) result = result * base % mod;
		base = base * base % mod;
	}
	return result;
}
// In-place iterative number-theoretic transform modulo `mod`.
//   a   : array of at least lim entries, transformed in place
//   lim : transform length (a power of two); the global bit-reversal table
//         r[0..lim) must already have been filled for this length by the caller
//   t   : 1 = forward transform; any other value = inverse transform, which
//         also multiplies every entry by lim^{-1}
void NTT(int a[] , int lim , int t)
{
	// Bit-reversal permutation so the butterflies below can run bottom-up.
	for (int i = 1; i < lim; i++) if (i < r[i]) swap(a[i] , a[r[i]]);
	for (int i = 1; i < lim; i <<= 1)
	{
		// Root of unity of order 2i: g^step for the forward transform and
		// g^(-step) = g^(mod - 1 - step) for the inverse (g^(mod-1) == 1 because
		// mod is prime). Deriving the inverse root from g means correctness no
		// longer depends on a separately hand-typed inverse constant.
		int step = (mod - 1) / (i << 1);
		int ome = qpow(g , t == 1 ? step : mod - 1 - step);
		for (int j = 0; j < lim; j += i << 1)
		{
			int w = 1;
			for (int k = 0; k < i; k++ , (w *= ome) %= mod)
			{
				int x = a[j + k] , y = w * a[j + k + i] % mod;
				a[j + k] = (x + y) % mod;
				a[j + k + i] = (x - y + mod) % mod;
			}
		}
	}
	if (t == 1) return;
	// Inverse transform: scale by lim^{-1} (Fermat's little theorem).
	int inv = qpow(lim , mod - 2);
	for (int i = 0; i < lim; i++) a[i] = (a[i] % mod * inv + mod) % mod;
	return;
}
// Newton iteration for the polynomial inverse: on return b[0..n) holds
// a^{-1} mod x^n and b[n..lim) is zeroed, where lim is the transform length
// used at this level.
// Side effects: overwrites the global scratch array c and the global
// bit-reversal table r.
// NOTE(review): requires a[0] to be invertible mod `mod` and n >= 1
// (n == 0 would recurse forever, since (0 + 1) >> 1 == 0).
void Div(int a[] , int b[] , int n)
{
	// Base case: the inverse of a constant term is its modular inverse.
	if (n == 1)
	{
		b[0] = qpow(a[0] , mod - 2);
		return;
	}
	// First obtain the inverse modulo x^ceil(n/2).
	Div(a , b , (n + 1) >> 1);

	// Smallest power of two >= 2n, plus its bit count for the reversal table.
	int len = 1;
	int bits = 0;
	for (; len < 2 * n; len <<= 1) bits++;
	for (int i = 1; i < len; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bits - 1));

	// c = a truncated to n terms, zero-padded to len.
	for (int i = 0; i < len; i++) c[i] = i < n ? a[i] : 0;
	NTT(c , len , 1);
	NTT(b , len , 1);
	// Point-wise Newton step: b <- b * (2 - c * b).
	for (int i = 0; i < len; i++)
	{
		int cb = c[i] * b[i] % mod;
		b[i] = b[i] * (2 - cb + mod) % mod;
	}
	NTT(b , len , -1);
	// Truncate to n terms so the next (larger) level sees zero padding.
	for (int i = n; i < len; i++) b[i] = 0;
}
// Formal derivative of the first n coefficients of a:
// b[i] = (i + 1) * a[i + 1] mod `mod` for 0 <= i < n - 1, and b[n - 1] = 0.
void Derivation(int a[] , int b[] , int n)
{
	for (int i = 0; i + 1 < n; i++) b[i] = a[i + 1] * (i + 1) % mod;
	b[n - 1] = 0;
}
// Formal integral with zero constant term, truncated to n coefficients:
// b[i] = a[i - 1] * i^{-1} mod `mod` for 1 <= i < n, and b[0] = 0.
void Integral(int a[] , int b[] , int n)
{
	for (int i = 1; i < n; i++)
	{
		// Modular inverse of i via Fermat's little theorem (mod is prime).
		int invI = qpow(i , mod - 2);
		b[i] = a[i - 1] * invI % mod;
	}
	b[0] = 0;
}
// Multiplies the polynomials in a and b via NTT and writes the product back
// into a. n and m bound the operand degrees: the transform length is the
// smallest power of two greater than n + m, so a product of degree <= n + m
// fits without wrap-around.
// Side effects: b is left in the transformed (point-value) domain, and the
// global bit-reversal table r is overwritten. All of a[0..len) and b[0..len)
// are transformed, so entries past the operands' degrees must be zero.
void Mul(int a[] , int b[] , int n , int m)
{
	int len = 1;
	int bits = 0;
	for (; len <= n + m; len <<= 1) ++bits;
	for (int i = 1; i < len; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bits - 1));
	NTT(a , len , 1);
	NTT(b , len , 1);
	for (int i = 0; i < len; i++) a[i] = a[i] * b[i] % mod;
	NTT(a , len , -1);
}
// Computes ln a(x) mod x^n into the global lna using ln a = integral(a' / a).
// Uses the globals _a (a^{-1}) and a_ (a', then a' * a^{-1}) as scratch.
// NOTE(review): ln is only defined as a formal power series when a[0] == 1;
// this is not checked here.
void Getln(int a[] , int n)
{
	Div(a , _a , n);           // _a = a^{-1} mod x^n
	Derivation(a , a_ , n);    // a_ = a'
	Mul(a_ , _a , n - 1 , n);  // a_ = a' * a^{-1}
	Integral(a_ , lna , n);    // lna = integral of a_, constant term 0
}
// Reads n and the n coefficients of a(x), then prints the n coefficients of
// ln a(x) mod x^n, each followed by a space.
signed main ()
{
	// Bail out on malformed input or n < 1: Div would recurse forever for
	// n == 0 ((0 + 1) >> 1 == 0) and Derivation would write b[-1].
	if (scanf("%lld" , &n) != 1 || n < 1) return 0;
	for (int i = 0; i < n; i++) scanf("%lld" , &a[i]);
	Getln(a , n);
	for (int i = 0; i < n; i++) printf("%lld " , lna[i]);
	return 0;
}
2024/12/16 21:40
加载中...