3模MTT求助
查看原帖
3模MTT求助
982518
sjwhsss楼主2024/12/19 15:03

RT,不开__int128会爆longlong,所有能模的都模了,修改取模的地方得分从0~50都有,不知道哪里有问题

#include <bits/stdc++.h>
#define int __int128
using namespace std;
const int maxn = 1e6+5 , mod1 = 998244353 , mod2 = 469762049 , mod3 = 1004535809 , G = 3 , G1 = 332748118 , G2 = 156587350 , G3 = 334845270;
int a1[maxn] , a2[maxn] , a3[maxn] , b1[maxn] , b2[maxn] , b3[maxn] , ans[maxn] , r[maxn];
int qpow(int a , int b , int mod)
{
	int res = 1;
	while(b)
	{
		if (b & 1)(res*=a)%=mod;
		(a*=a)%=mod;
		b>>=1;
	}
	return res;
}
void NTT(int *a , int lim , int t , int mod , int Gi)
{
	for (int i = 1; i < lim; i++) if (i < r[i])swap(a[i] , a[r[i]]);
	for (int i = 1; i < lim; i<<=1)
	{
		int ome = qpow(t == 1 ? G : Gi , (mod - 1)/(i<<1) , mod);
		for (int j = 0; j < lim; j+=i<<1)
		{
			int w = 1;
			for (int k = 0; k < i; k++ , (w*=ome)%=mod)
			{
				int x = a[j + k] , y = w * a[j + k + i] % mod;
				a[j + k] = (x + y)%mod;
				a[j + k + i] = (x - y + mod)%mod;
			}
		}
	}
	if (t == 1)return;
	int inv = qpow(lim , mod - 2 , mod);
	for (int i = 0; i < lim; i++) (a[i]*=inv)%=mod;
	return;
}
void Mul(int *a , int *b , int n , int m , int mod , int Gi)
{
	int lim = 1 , t = 0;
	while(lim <= n + m)lim<<=1,t++;
	for (int i = 1; i < lim; i++) r[i]=(r[i>>1]>>1)|((i&1)<<t-1);
	NTT(a , lim , 1 , mod , Gi) , NTT(b , lim , 1 , mod , Gi);
	for (int i = 0; i < lim; i++)(a[i]*=b[i])%=mod;
	NTT(a , lim , -1 , mod , Gi);
	return;
}
inline int read()
{
	char ch=getchar();int x=0;
	while(isdigit(ch)^1)ch=getchar();
	while(isdigit(ch))x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
	return x;
}
inline void print(int x)
{
	if (x > 9)print(x/10);
	putchar(x%10^48);
	return;
}
signed main ()
{
	int n=read() , m=read() , p=read();
	for (int i = 0; i <= n; i++) a2[i]=a3[i]=a1[i]=read();
	for (int i = 0; i <= m; i++) b2[i]=b3[i]=b1[i]=read();
	Mul(a1 , b1 , n , m , mod1 , G1);
	Mul(a2 , b2 , n , m , mod2 , G2);
	Mul(a3 , b3 , n , m , mod3 , G3);
	for (int i = 0; i < n + m + 1; i++)
	{
		int x = ((a2[i] - a1[i] + mod2 + mod2)%mod2 * qpow(mod1 , mod2 - 2 , mod2)%mod2 * mod1%(mod1*mod2) + a1[i])%(mod1*mod2);
		ans[i] = ((((a3[i] - x%mod3 + mod3 + mod3)%mod3%(mod1*mod2*mod3) * qpow(mod1%mod3 * mod2%mod3 , mod3 - 2 , mod3)%mod3)%(mod1*mod2*mod3) * (mod1*mod2)%(mod1*mod2*mod3))%(mod1*mod2*mod3) + x%(mod1*mod2*mod3))%p;
	}
	for (int i = 0; i < n + m + 1; i++) print(ans[i]),putchar(32);
	return 0;
}
2024/12/19 15:03
加载中...