求助
查看原帖
求助
982518
sjwhsss（楼主） 2024/12/18 20:28

RT，和题解对拍后发现，最后求 ln 时逆元部分的结果有问题，但找不到错在哪里。这份代码里的函数单独提交可以 AC 多项式求逆、ln、exp 的模板题，并且把这里的函数换成前面几道题原来的代码后，结果还是一样。

#include <bits/stdc++.h>
#define ll long long
#define int long long
using namespace std;
// Buffer capacity, the NTT-friendly prime 998244353 = 119 * 2^23 + 1,
// the root G = 3 used for forward transforms, and Gi = G^{-1} mod p for inverse ones.
const int maxn = 1e6+5;
const int mod = 998244353;
const int G = 3;
const int Gi = 332748118;
// Global coefficient buffers:
//   a, b      - input polynomial and final answer (read / printed by main)
//   a_, inva  - derivative and inverse scratch used by GetLn
//   sa        - integral scratch used by GetLn
//   c         - copy of the operand inside GetInv
//   r         - bit-reversal permutation filled before each NTT
//   lnb, ln   - ln(A) for GetPow and the Newton residual inside GetExp
int a[maxn];
int b[maxn];
int a_[maxn];
int inva[maxn];
int sa[maxn];
int c[maxn];
int r[maxn];
int lnb[maxn];
int ln[maxn];
/// Reads the next non-negative decimal integer from stdin, reduced modulo `mod`
/// digit by digit (so arbitrarily long exponents such as k can be read).
/// Non-digit characters before the number are skipped; signs are not handled.
/// Returns 0 if EOF is reached before any digit.
inline ll read()
{
	// Keep getchar()'s result as an int: EOF stays distinguishable, and isdigit()
	// is only defined for EOF and values representable as unsigned char.
	int ch = getchar();
	ll x = 0;
	// Test with !isdigit(): the standard only promises a *non-zero* result for digits
	// (glibc's function returns the ctype bit mask), so `isdigit(ch)^1` may never be 0
	// and the skip loop would run forever unless the compiler happens to fold isdigit.
	while (ch != EOF && !isdigit(ch)) ch = getchar();
	while (isdigit(ch))
	{
		x = (x * 10 + (ch - '0')) % mod;
		ch = getchar();
	}
	return x;
}
/// Binary exponentiation: returns a^b mod `mod` for b >= 0.
/// The base should already be reduced mod `mod`, otherwise base*base can overflow.
inline ll qpow(ll a , ll b)
{
	ll result = 1;
	for (ll base = a; b > 0; b >>= 1, base = base * base % mod)
		if (b & 1) result = result * base % mod;
	return result;
}
void NTT(int *a , int lim , int t)
{
	for (int i = 1; i < lim; i++) if (i < r[i])swap(a[i] , a[r[i]]);
	for (int i = 1; i < lim; i<<=1)
	{
		ll ome = qpow(t == 1 ? G : Gi , (mod - 1)/(i<<1));
		for (int j = 0; j < lim; j+=i<<1)
		{
			ll w = 1;
			for (int k = 0; k < i; k++ , (w*=ome)%=mod)
			{
				ll x = a[j + k] , y = w * a[j + k + i] % mod;
				a[j + k] = (x + y)%mod;
				a[j + k + i] = (x - y + mod)%mod;
			}
		}
	}
	if (t == 1)return;
	ll inv = qpow(lim , mod - 2);
	for (int i = 0; i < lim; i++)(a[i]*=inv)%=mod;
	return;
}
/// Polynomial product a := a * b, where a has n coefficients and b has m.
/// The transform length lim is the smallest power of two >= n + m, so the cyclic
/// convolution equals the full linear product; on return a[0..lim) holds it
/// (n + m - 1 meaningful terms, zeros above).
/// Side effects: overwrites the global r[], and leaves b in the NTT (evaluation) domain.
void Mul(int *a , int *b , int n , int m)
{
	int lim = 1 , t = 0;
	while(lim < n + m)lim<<=1,t++;
	for (int i = 1; i < lim; i++)r[i]=(r[i>>1]>>1)|((i&1)<<t-1);
	// Only the first n / m entries are operands. Whatever sits above them (e.g. a product
	// or NTT values left by an earlier call on the same buffer) would otherwise be
	// transformed too and wrap around into the low coefficients of the result.
	for (int i = n; i < lim; i++)a[i]=0;
	for (int i = m; i < lim; i++)b[i]=0;
	NTT(a , lim , 1) , NTT(b , lim , 1);
	for (int i = 0; i < lim; i++)(a[i]*=b[i])%=mod;
	NTT(a , lim , -1);
}
/// Power-series inverse by Newton iteration: on return b[0..n) = A^{-1} mod x^n and
/// b[n..lim) = 0, where lim is the smallest power of two >= 2n.
/// Requires a[0] to be invertible mod p. Uses the global buffers c[] and r[].
void GetInv(int *a , int *b , int n)
{
	if (n == 1){b[0]=qpow(a[0] , mod - 2);return;}
	const int half = n + 1 >> 1;
	// b[0..half) = B0 = A^{-1} mod x^half.
	GetInv(a , b , half);
	int lim = 1 , t = 0;
	while(lim < (n<<1))lim<<=1,t++;
	for (int i = 1; i < lim; i++)r[i]=(r[i>>1]>>1)|((i&1)<<t-1);
	// c = A mod x^n, zero-padded to lim.
	for (int i = 0; i < n; i++)c[i]=a[i];
	for (int i = n; i < lim; i++)c[i]=0;
	// b must be exactly B0 before it is transformed. The recursion only cleared up to its
	// own (smaller) transform length, and a caller may hand in a buffer still holding
	// stale values (e.g. NTT output from an earlier Mul), so clear the whole tail here.
	for (int i = half; i < lim; i++)b[i]=0;
	NTT(c , lim , 1) , NTT(b , lim , 1);
	// Newton step, pointwise: B = B0 * (2 - A * B0).
	for (int i = 0; i < lim; i++) b[i]=b[i]*(2 - c[i]*b[i]%mod+mod)%mod;
	NTT(b , lim , -1);
	// Truncate to n terms.
	for (int i = n; i < lim; i++)b[i]=0;
}
/// Formal derivative: b[i] = (i + 1) * a[i + 1] for 0 <= i < n - 1, and b[n - 1] = 0.
/// Writes go to a lower index than the read, so a and b may be the same array.
void Derivation(int *a , int *b , int n)
{
	for (int i = 0; i + 1 < n; i++) b[i] = 1ll * a[i + 1] * (i + 1) % mod;
	// The top coefficient of the truncated derivative is zero.
	b[n - 1] = 0;
}
void Integral(int *a , int *b , int n)
{
	for (int i = 1; i < n; i++)b[i]=1ll*a[i - 1]*qpow(i , mod - 2)%mod;
	b[0]=0;
	return;
}
/// Power-series logarithm: b[0..n) = ln A mod x^n, computed as the integral of A' * A^{-1}.
/// Assumes a[0] == 1: the constant term of the result is fixed to 0 by Integral.
/// Uses the global scratch buffers a_, inva, sa, c and r.
void GetLn(int *a , int *b , int n)
{
	// GetInv and Mul below both transform over [0, lim) with lim the smallest power of
	// two >= 2n. A previous call on a longer input (e.g. GetPow's full-length ln before
	// GetExp's shorter ones) leaves a product in a_ and NTT-domain values in inva up
	// there. Clearing only [0, n) lets that junk leak into the inverse and the product,
	// so clear the whole transform range.
	int lim = 1;
	while (lim < (n << 1)) lim <<= 1;
	for (int i = 0; i < lim; i++)a_[i]=inva[i]=sa[i]=c[i]=0;
	Derivation(a , a_ , n);
	GetInv(a , inva , n);
	// a_ = A' * A^{-1}; only the first n terms are used below.
	Mul(a_ , inva , n , n);
	Integral(a_ , sa , n);
	for (int i = 0; i < n; i++)b[i]=sa[i];
}
/// Power-series exponential by Newton iteration: on return b[0..n) = exp(A) mod x^n.
/// Assumes a[0] == 0 (the base case fixes b[0] = 1).
/// Each step computes B = B0 * (1 - ln B0 + A) mod x^n, using the global buffer ln[].
void GetExp(int *a , int *b , int n)
{
	if (n == 1){b[0]=1;return;}
	const int half = n + 1 >> 1;
	// b[0..half) = B0 = exp(A) mod x^half.
	GetExp(a , b , half);
	// Transform length Mul uses below (smallest power of two >= 2n). When n is not a
	// power of two this exceeds 2n, so tails must be cleared up to lim, not just 2n.
	int lim = 1;
	while(lim < (n<<1))lim<<=1;
	// b must be exactly B0 when it is fed to GetLn and Mul.
	for (int i = half; i < lim; i++)b[i]=0;
	GetLn(b , ln , n);
	// ln = 1 - ln(B0) + A (mod x^n).
	for (int i = 0; i < n; i++)ln[i]=(a[i]-ln[i]+mod)%mod;
	ln[0]=(ln[0]+1)%mod;
	// Do not rely on the caller having zeroed ln above n.
	for (int i = n; i < lim; i++)ln[i]=0;
	Mul(b , ln , n , n);
	// Truncate the product to n terms; ln is left in NTT form by Mul, so clear its tail too.
	for (int i = n; i < lim; i++)b[i]=ln[i]=0;
}
/// Polynomial power: b[0..n) = A^k mod x^n, computed as exp(k * ln A).
/// Assumes a[0] == 1 (required by GetLn); k is expected already reduced mod p.
/// Uses the global buffers lnb and ln.
void GetPow(int *a , int *b , int n , ll k)
{
	GetLn(a , lnb , n);
	for (int i = 0; i < n; i++)
	{
		lnb[i] = lnb[i] * k % mod;
		// Reset the scratch buffer GetExp writes its Newton residual into.
		ln[i] = 0;
	}
	GetExp(lnb , b , n);
	for (int i = 0; i < n; i++) b[i] %= mod;
}
signed main ()
{
	int n = read() , lim = 1;
	ll k = read();
	for (int i = 0; i < n; i++)a[i]=read();
	while(lim < n)lim<<=1;
	GetPow(a , b , lim , k);
//	cout << lim << endl;
//	GetExp(a , b , lim);
	for (int i = 0; i < n; i++)printf("%lld " , b[i]);
	return 0;
}
2024/12/18 20:28
加载中...