萌新求助分治 NTT
查看原帖
萌新求助分治 NTT
496840
SAMSHAWCRAFT楼主2021/12/25 23:46

确定了多项式板子没错,这个多项式板子交到模板题和其他生成函数题没发现出锅。和题解的代码对拍了,我的分治 NTT 写错了,但是改了一晚上改不出来了,不明白哪里写错了。希望大佬能帮我指一下,十分感谢!

分治 NTT 就是那个 cdq 函数,名字是瞎起的。

#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <numeric>
#define qaq inline
using ll=long long;
const ll mod=998244353;
const int sz=1e6+19;
int n,m,t,k,revid[sz],lim,limbit;
ll a[sz],b[sz];
ll fac[sz],invfac[sz];
ll F[sz],G[sz],f[sz],g[sz];
qaq ll fastPow(ll a,ll b,const ll &mod){
    ll res=1;
    for(ll base=a%mod,t=b;t!=0;t>>=1,base=base*base%mod)
        if(t&1) res=res*base%mod;
    return res;
}
qaq ll inverse(ll n,const ll &mod){
    return fastPow(n,mod-2,mod);
}
qaq void NTT(int limit,ll *arr,int sign){
    int invlim=inverse(limit,mod);
    for(int cx=0;cx<limit;++cx)
        if(cx<revid[cx]) std::swap(arr[cx],arr[revid[cx]]);
    for(int l=2;l<=limit;l<<=1){
        ll Wn=fastPow(sign==1?3:inverse(3,mod),(mod-1)/l,mod);
        for(int cx=0;cx<limit;cx+=l){
            ll w=1;
            for(int cy=0;cy<(l>>1);++cy,w=w*Wn%mod){
                ll tmp1=arr[cx+cy];
                ll tmp2=w*arr[cx+cy+(l>>1)]%mod;
                arr[cx+cy]=(tmp1+tmp2)%mod;
                arr[cx+cy+(l>>1)]=(tmp1-tmp2+mod)%mod;
            }
        }
    }
    if(sign==-1){
        for(int cx=0;cx<limit;++cx)
            arr[cx]=arr[cx]*invlim%mod;
    }
}
qaq void polydiff(int limit,ll *f,ll *g){
    for(int cx=1;cx<limit;++cx)
        g[cx-1]=f[cx]*cx%mod;
    g[limit-1]=0;
}
qaq void polyinteg(int limit,ll *f,ll *g){
    for(int cx=1;cx<limit;++cx)
        g[cx]=f[cx-1]*inverse(cx,mod)%mod;
    g[0]=0;
}
qaq void polyinv(int n,ll *f,ll *g){
    static ll c[sz];
    std::fill(c,c+sz,0);
    if(n==1){
        g[0]=inverse(f[0],mod);
        return;
    }
    polyinv((n+1)>>1,f,g);
    for(lim=1,limbit=0;lim<2LL*n;lim<<=1,limbit++);
    for(int cx=1;cx<lim;++cx)
        revid[cx]=(revid[cx>>1]>>1)|((cx&1)<<(limbit-1));
    for(int cx=0;cx<n;++cx)
        c[cx]=f[cx];
    for(int cx=n;cx<lim;++cx)
        c[cx]=0;
    NTT(lim,g,1),NTT(lim,c,1);
    for(int cx=0;cx<lim;++cx)
        g[cx]=(2LL-g[cx]*c[cx]%mod+mod)%mod*g[cx]%mod;
    NTT(lim,g,-1);
    for(int cx=n;cx<lim;++cx)
        g[cx]=0;
}
void polyln(int n,ll *f,ll *g){
    static ll df[sz],invf[sz];
    std::fill(df,df+sz,0);
    std::fill(invf,invf+sz,0);
    polydiff(n,f,df);
    polyinv(n,f,invf);
    for(lim=1,limbit=0;lim<2LL*n;lim<<=1,limbit++);
    for(int cx=1;cx<lim;++cx)
        revid[cx]=(revid[cx>>1]>>1)|((cx&1)<<(limbit-1));
    NTT(lim,df,1),NTT(lim,invf,1);
    for(int cx=0;cx<lim;++cx)
        df[cx]=df[cx]*invf[cx]%mod;
    NTT(lim,df,-1);
    polyinteg(n,df,g);
}
void facPrepare(int n=sz-1){
    fac[0]=1;
    for(ll cx=1;cx<=n;++cx)
        fac[cx]=fac[cx-1]*cx%mod;
    invfac[n]=inverse(fac[n],mod);
    for(ll cx=n;cx!=0;--cx)
        invfac[cx-1]=invfac[cx]*cx%mod;
}
void cdq(int ln,int rn,ll *arr,ll *R){
    if(ln==rn){
        R[0]=1,R[1]=mod-arr[ln];
        return;
    }
    static ll lf[sz],rf[sz];
    int mid=(ln+rn)>>1;
    cdq(ln,mid,arr,lf);
    cdq(mid+1,rn,arr,rf);
    for(lim=1,limbit=0;lim<=(rn-ln+1);lim<<=1,limbit++);
    for(int cx=1;cx<lim;++cx)
        revid[cx]=(revid[cx>>1]>>1)|((cx&1)<<(limbit-1));
    for(int cx=mid-ln+2;cx<lim;++cx)
        lf[cx]=0;
    for(int cx=rn-mid+1;cx<lim;++cx)
        rf[cx]=0;
    NTT(lim,lf,1),NTT(lim,rf,1);
    for(int cx=0;cx<lim;++cx)
        lf[cx]=lf[cx]*rf[cx]%mod;
    NTT(lim,lf,-1);
    for(int cx=0;cx<rn-ln+2;++cx)
        R[cx]=lf[cx];
}
int main(){
    scanf("%d%d",&n,&m);
    for(int cx=1;cx<=n;++cx)
        scanf("%lld",a+cx);
    for(int cx=1;cx<=m;++cx)
        scanf("%lld",b+cx);
    scanf("%d",&k);
    t=k,k=std::max({n,m,k});
    facPrepare();
    cdq(1,n,a,F);
    polyln(k+2,F,f);
    polydiff(k+2,f,F);
    cdq(1,m,b,G);
    polyln(k+2,G,g);
    polydiff(k+2,g,G);
    for(int cx=k;cx!=0;--cx)
        F[cx]=(mod-F[cx-1]),G[cx]=(mod-G[cx-1]);
    F[0]=n,G[0]=m;
    for(int cx=0;cx<=k;++cx)
        F[cx]=F[cx]*invfac[cx]%mod;
    for(int cx=0;cx<=k;++cx)
        G[cx]=G[cx]*invfac[cx]%mod;
    for(lim=1,limbit=0;lim<2LL*k+4;lim<<=1,limbit++);
    for(int cx=1;cx<lim;++cx)
        revid[cx]=(revid[cx>>1]>>1)|((cx&1)<<(limbit-1));
    NTT(lim,F,1),NTT(lim,G,1);
    for(int cx=0;cx<lim;++cx)
        F[cx]=F[cx]*G[cx]%mod;
    NTT(lim,F,-1);
    ll invnm=inverse(1LL*n*m,mod);
    for(int cx=1;cx<=t;++cx)
        printf("%lld\n",F[cx]*fac[cx]%mod*invnm%mod);
    return 0;
}
2021/12/25 23:46
加载中...