萌新刚出生,求调代码
查看原帖
萌新刚出生,求调代码
1125685
Frielen楼主2024/12/3 14:12

评测记录

但是普通版这份代码是能过的。

这个 52pts 的挂法好像还挺普遍的。但我分不到 60pts,看不到挂的原因。并且也没有看到这种警示。

代码如下

#include<cstdio>
#include<cmath>
#include<algorithm>
#define int long long
using std::swap;
using std::reverse;
const int N=3e6+9,p=998244353;
struct Complex{
	int x,y;
};
int w;
Complex multi(Complex a,Complex b,int mod){
	Complex res;
	res.x=(((a.x*b.x)%mod+((a.y*b.y)%mod*w)%mod)%mod+mod)%mod;
	res.y=(((a.x*b.y)%mod+(a.y*b.x)%mod)%mod+mod)%mod;
	return res;
}
int qpow(int a,int p,int mod){
	int res=1;
	while(p){
		if(p&1) res=(res*a)%mod;
		a=(a*a)%mod;
		p>>=1;
	}
	return res;
}
int compqpow(Complex a,int p,int mod){
	Complex res={1,0};
	while(p){
		if(p&1) res=multi(res,a,mod);
		a=multi(a,a,mod);
		p>>=1;
	}
	return res.x%mod;
}
int cipolla(int n,int p){
	if(qpow(n,(p-1)/2,p)==p-1) return -1;
	int r;
	while(1){
		r=rand()%p;
		w=(((r*r)%p-n)%p+p)%p;
		if(qpow(w,(p-1)/2,p)==p-1) break;
	}
	return compqpow({r,1},(p+1)/2,p);
}
int qpow(int a,int k){
	int base=1;
	while(k){
		if(k&1) base=(base*a)%p;
		a=(a*a)%p;
		k>>=1;
	}
	return base%p;
}
int r[N],inv[N];
int s[N];
void NTT(int *A,int len,int typ){
    int lim=1,l=0;
    while(lim<len) lim<<=1,l++;
    for(int i=0;i<lim;i++) r[i]=(r[i>>1]>>1)|((i&1)<<(l-1)); for(int i=0;i<lim;i++) if(i<r[i]) swap(A[i],A[r[i]]);
    for(int mid=1;mid<lim;mid<<=1){
        int R=mid<<1,Wn=qpow(3,(p-1)/R);
		s[0]=1;
        for(int j=1;j<mid;j++) s[j]=s[j-1]*Wn%p;
        for(int j=0;j<lim;j+=R){
            for(int k=0;k<mid;k++){
                int x=A[j+k],y=s[k]*A[j+k+mid]%p;
                A[j+k]=(x+y)%p,A[j+k+mid]=(x-y+p)%p;
            }
        }
    }
    if(typ==-1){
        reverse(A+1,A+lim);
        for(int i=0,Inv=inv[lim];i<lim;i++) A[i]=A[i]*Inv%p;
    }
}
void deri/*derivation*/(int *A,int *B,int len){
	for(int i=1;i<len;i++) B[i-1]=A[i]*i%p;
	B[len-1]=0;
}
void inte/*integral*/(int *A,int *B,int len){
	for(int i=1;i<len;i++) B[i]=A[i-1]*inv[i]%p;
	B[0]=0;
}
int A[N],B[N],C[N],D[N],Ans1[N],ln[N],tot;
void NTT_inverse(int *a,int *b,int len){
    if(len==1){
    	b[0]=qpow(a[0],p-2);
    	return;
	}
    NTT_inverse(a,b,len>>1);
    for(int i=0;i<len;i++) C[i]=a[i],D[i]=b[i];
    int l=(len<<1);
    NTT(C,l,1);
	NTT(D,l,1);
    for(int i=0;i<l;i++) C[i]=(C[i]*D[i]%p)*D[i]%p;
    NTT(C,l,-1);
    for(int i=0;i<len;i++) b[i]=(b[i]+b[i]-C[i]+p)%p;
    for(int i=0;i<l;i++) C[i]=D[i]=0;
}
void getln(int *a,int *b,int len){
	deri(a,A,len);
	NTT_inverse(a,B,len);
	int l=(len<<1);
    NTT(A,l,1);
	NTT(B,l,1);
    for(int i=0;i<l;i++) A[i]=(A[i]*B[i])%p;
    NTT(A,l,-1);
	inte(A,b,l);
    for(int i=0;i<l;i++) A[i]=B[i]=0;
}
void getexp(int *A,int *B,int len){
	if(len==1){
		B[0]=1;
		return;
	}
	getexp(A,B,len>>1);
	getln(B,ln,len);
	ln[0]=(A[0]+1-ln[0]+p)%p;
	int l=(len<<1);
    for(int i=1;i<len;i++) ln[i]=(A[i]-ln[i]+p)%p;
    NTT(ln,l,1);
    NTT(B,l,1);
    for(int i=0;i<l;i++) B[i]=B[i]*ln[i]%p;
    NTT(B,l,-1);
    for(int i=len;i<l;i++) B[i]=ln[i]=0;
}
int X[N],Y[N];
int min(int a,int b){
	return a<b?a:b;
}
void NTT_sqrt(int *a,int *b,int len){
	if(len==1){
		int t=cipolla(a[0]%p,p);
		b[0]=min(t,p-t);
		return;
	}
	NTT_sqrt(a,b,len>>1);
	int s=(len<<1);
	for(int i=0;i<s;i++) Y[i]=0;
	NTT_inverse(b,Y,len);
	int lim=1,l=0;
	while(lim<s) lim<<=1,l++;
	for(int i=0;i<len;i++) X[i]=a[i];
	for(int i=len;i<lim;i++) X[i]=0;
	NTT(X,lim,1);
	NTT(b,lim,1);
	NTT(Y,lim,1);
	for(int i=0;i<lim;i++) b[i]=qpow(2,p-2)*((b[i]+X[i]*Y[i]%p)%p)%p;
	NTT(b,lim,-1);
	for(int i=len;i<lim;i++) b[i]=0;
}
int read(){
	int x=0;
	char c=getchar();
	while(c<'0'||c>'9') c=getchar();
	while(c>='0'&&c<='9'){
		x=((x*10)+(int)(c-'0'))%p;
		c=getchar();
	}
	return x%p;
}
int A1[N],A2[N],Ans[N];
signed main(){
	int n,k,num=0,x,a0,lim=1,l;
	scanf("%lld",&n);
	while(lim<=n) lim<<=1;
    l=lim<<1;
	for(int i=0;i<=l;i++) inv[i]=qpow(i,p-2);
	k=read();
	bool flag=0;
	for(int i=0;i<n;i++){
		scanf("%lld",&x);
		!(x||flag)?num++:flag=1,A1[i-num]=x;
	}
	if(!flag){
		for(int i=0;i<n;i++) printf("0 ");
		return 0;
	}
	a0=A1[0];
	lim=1;
	while(lim<=n-num) lim<<=1;
    getln(A1,A2,lim);
    for(int i=0;i<n-num;i++) A2[i]=A2[i]*k%p;
    getexp(A2,Ans,lim);
    num*=k;
    for(int i=0;i<min(num,n);i++) printf("0 ");
    for(int i=num;i<n;i++) printf("%lld ",Ans[i-num]*qpow(a0,k)%p);
	return 0;
}
2024/12/3 14:12
加载中...