问题二三WA了/ll
#include<bits/stdc++.h>
#define int long long
#define mod 998244353ll
using namespace std;
int T,n,a,b,c,inv2,inv6;
struct node{
int f1,f2,f3;
}ans;
int fpow(int x,int y){
int res=1;
while(y){
if(y&1)res=(res*x)%mod;
x=(x*x)%mod;
y>>=1;
}
return res;
}
int sum1(int x){
int res=x;
res=(res*(x+1))%mod;
res=(res*inv2)%mod;
return res;
}
int sum2(int x){
int res=x;
res=(res*(x+1))%mod;
res=(res*(x+x+1))%mod;
res=(res*inv6)%mod;
return res;
}
node query(int n,int a,int b,int c){
node res,t1,t2;
res.f1=res.f2=res.f3=0;
int tmp;
if(a==0){
res.f1=((n+1)*(b/c))%mod;
res.f2=((n+1)*(b/c))%mod;
res.f2=(res.f2*(b/c))%mod;
res.f3=(sum1(n)*(b/c))%mod;
//cout<<res.f1<<' '<<res.f2<<' '<<res.f3<<' '<<a<<' '<<b<<' '<<c<<endl;
return res;
}
if(a>=c||b>=c){
t1=query(n,a%c,b%c,c);
tmp=((n+1)*(b/c))%mod;
res.f1=(res.f1+tmp)%mod;
res.f1=(res.f1+sum1(n)*(a/c))%mod;
res.f1=(res.f1+t1.f1)%mod;
res.f2=(res.f2+(b/c)*2*t1.f1)%mod;
res.f2=(res.f2+(a/c)*2*t1.f3)%mod;
res.f2=(res.f2+t1.f2)%mod;
tmp=((a/c)*(a/c))%mod;
res.f2=(res.f2+tmp*sum2(n))%mod;
tmp=((b/c)*(b/c))%mod;
res.f2=(res.f2+tmp*(n+1))%mod;
tmp=((b/c)*(a/c))%mod;
res.f2=(res.f2+tmp*sum2(n)*2)%mod;
res.f3=(res.f3+t1.f3)%mod;
res.f3=(res.f3+sum1(n)*(b/c))%mod;
res.f3=(res.f3+sum2(n)*(a/c))%mod;
return res;
}
tmp=(a*n+b)/c;
t1=query(tmp-1,c,c-b-1,a);
res.f1=(res.f1-t1.f1+n*tmp+mod)%mod;
tmp%=mod;
tmp=(tmp*(tmp+1))%mod;
res.f2=(res.f2+tmp*n)%mod;
res.f2=(res.f2-t1.f1+mod-t1.f1+mod)%mod;
res.f2=(res.f2-t1.f3+mod-t1.f3+mod)%mod;
res.f2=(res.f2-res.f1+mod)%mod;
tmp=(a*n+b)/c;
tmp%=mod;
res.f3=(res.f3+tmp*sum1(n)*2)%mod;
res.f3=(res.f3-t1.f1-t1.f2+mod*2)%mod;
res.f3=(res.f3*inv2)%mod;
return res;
}
signed main(){
cin>>T;
inv2=fpow(2,mod-2);
inv6=fpow(6,mod-2);
while(T--){
cin>>n>>a>>b>>c;
ans=query(n,a,b,c);
cout<<ans.f1<<' '<<ans.f2<<' '<<ans.f3<<endl;
}
}