确定了多项式板子没错,这个多项式板子交到模板题和其他生成函数题没发现出锅。和题解的代码对拍了,我的分治 NTT 写错了,但是改了一晚上改不出来了,不明白哪里写错了。希望大佬能帮我指一下,十分感谢!
分治 NTT 就是那个 cdq 函数,名字是瞎起的。
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <numeric>
#define qaq inline
using ll=long long;
const ll mod=998244353;
const int sz=1e6+19;
int n,m,t,k,revid[sz],lim,limbit;
ll a[sz],b[sz];
ll fac[sz],invfac[sz];
ll F[sz],G[sz],f[sz],g[sz];
qaq ll fastPow(ll a,ll b,const ll &mod){
ll res=1;
for(ll base=a%mod,t=b;t!=0;t>>=1,base=base*base%mod)
if(t&1) res=res*base%mod;
return res;
}
qaq ll inverse(ll n,const ll &mod){
return fastPow(n,mod-2,mod);
}
qaq void NTT(int limit,ll *arr,int sign){
int invlim=inverse(limit,mod);
for(int cx=0;cx<limit;++cx)
if(cx<revid[cx]) std::swap(arr[cx],arr[revid[cx]]);
for(int l=2;l<=limit;l<<=1){
ll Wn=fastPow(sign==1?3:inverse(3,mod),(mod-1)/l,mod);
for(int cx=0;cx<limit;cx+=l){
ll w=1;
for(int cy=0;cy<(l>>1);++cy,w=w*Wn%mod){
ll tmp1=arr[cx+cy];
ll tmp2=w*arr[cx+cy+(l>>1)]%mod;
arr[cx+cy]=(tmp1+tmp2)%mod;
arr[cx+cy+(l>>1)]=(tmp1-tmp2+mod)%mod;
}
}
}
if(sign==-1){
for(int cx=0;cx<limit;++cx)
arr[cx]=arr[cx]*invlim%mod;
}
}
qaq void polydiff(int limit,ll *f,ll *g){
for(int cx=1;cx<limit;++cx)
g[cx-1]=f[cx]*cx%mod;
g[limit-1]=0;
}
qaq void polyinteg(int limit,ll *f,ll *g){
for(int cx=1;cx<limit;++cx)
g[cx]=f[cx-1]*inverse(cx,mod)%mod;
g[0]=0;
}
qaq void polyinv(int n,ll *f,ll *g){
static ll c[sz];
std::fill(c,c+sz,0);
if(n==1){
g[0]=inverse(f[0],mod);
return;
}
polyinv((n+1)>>1,f,g);
for(lim=1,limbit=0;lim<2LL*n;lim<<=1,limbit++);
for(int cx=1;cx<lim;++cx)
revid[cx]=(revid[cx>>1]>>1)|((cx&1)<<(limbit-1));
for(int cx=0;cx<n;++cx)
c[cx]=f[cx];
for(int cx=n;cx<lim;++cx)
c[cx]=0;
NTT(lim,g,1),NTT(lim,c,1);
for(int cx=0;cx<lim;++cx)
g[cx]=(2LL-g[cx]*c[cx]%mod+mod)%mod*g[cx]%mod;
NTT(lim,g,-1);
for(int cx=n;cx<lim;++cx)
g[cx]=0;
}
void polyln(int n,ll *f,ll *g){
static ll df[sz],invf[sz];
std::fill(df,df+sz,0);
std::fill(invf,invf+sz,0);
polydiff(n,f,df);
polyinv(n,f,invf);
for(lim=1,limbit=0;lim<2LL*n;lim<<=1,limbit++);
for(int cx=1;cx<lim;++cx)
revid[cx]=(revid[cx>>1]>>1)|((cx&1)<<(limbit-1));
NTT(lim,df,1),NTT(lim,invf,1);
for(int cx=0;cx<lim;++cx)
df[cx]=df[cx]*invf[cx]%mod;
NTT(lim,df,-1);
polyinteg(n,df,g);
}
void facPrepare(int n=sz-1){
fac[0]=1;
for(ll cx=1;cx<=n;++cx)
fac[cx]=fac[cx-1]*cx%mod;
invfac[n]=inverse(fac[n],mod);
for(ll cx=n;cx!=0;--cx)
invfac[cx-1]=invfac[cx]*cx%mod;
}
void cdq(int ln,int rn,ll *arr,ll *R){
if(ln==rn){
R[0]=1,R[1]=mod-arr[ln];
return;
}
static ll lf[sz],rf[sz];
int mid=(ln+rn)>>1;
cdq(ln,mid,arr,lf);
cdq(mid+1,rn,arr,rf);
for(lim=1,limbit=0;lim<=(rn-ln+1);lim<<=1,limbit++);
for(int cx=1;cx<lim;++cx)
revid[cx]=(revid[cx>>1]>>1)|((cx&1)<<(limbit-1));
for(int cx=mid-ln+2;cx<lim;++cx)
lf[cx]=0;
for(int cx=rn-mid+1;cx<lim;++cx)
rf[cx]=0;
NTT(lim,lf,1),NTT(lim,rf,1);
for(int cx=0;cx<lim;++cx)
lf[cx]=lf[cx]*rf[cx]%mod;
NTT(lim,lf,-1);
for(int cx=0;cx<rn-ln+2;++cx)
R[cx]=lf[cx];
}
int main(){
scanf("%d%d",&n,&m);
for(int cx=1;cx<=n;++cx)
scanf("%lld",a+cx);
for(int cx=1;cx<=m;++cx)
scanf("%lld",b+cx);
scanf("%d",&k);
t=k,k=std::max({n,m,k});
facPrepare();
cdq(1,n,a,F);
polyln(k+2,F,f);
polydiff(k+2,f,F);
cdq(1,m,b,G);
polyln(k+2,G,g);
polydiff(k+2,g,G);
for(int cx=k;cx!=0;--cx)
F[cx]=(mod-F[cx-1]),G[cx]=(mod-G[cx-1]);
F[0]=n,G[0]=m;
for(int cx=0;cx<=k;++cx)
F[cx]=F[cx]*invfac[cx]%mod;
for(int cx=0;cx<=k;++cx)
G[cx]=G[cx]*invfac[cx]%mod;
for(lim=1,limbit=0;lim<2LL*k+4;lim<<=1,limbit++);
for(int cx=1;cx<lim;++cx)
revid[cx]=(revid[cx>>1]>>1)|((cx&1)<<(limbit-1));
NTT(lim,F,1),NTT(lim,G,1);
for(int cx=0;cx<lim;++cx)
F[cx]=F[cx]*G[cx]%mod;
NTT(lim,F,-1);
ll invnm=inverse(1LL*n*m,mod);
for(int cx=1;cx<=t;++cx)
printf("%lld\n",F[cx]*fac[cx]%mod*invnm%mod);
return 0;
}