WA on #21 求助
查看原帖
WA on #21 求助
53930
Lstdo楼主2020/11/23 21:03

主要参考了这篇博客,写法比较类似。下文中的 std 指这篇博客中的代码。

目前进展:

  1. 构造了一组会发生乘 00 的小(指 n=70n=70) 数据通过。
  2. 各种范围的随机数据对拍通过。
  3. 在 #21 下,输出最终影响答案的 f(i,L),f(i,L1),g(i,L)f(i,L),f(i,L-1),g(i,L),有且仅有 f(1,L)f(1,L)f(1,L1)f(1,L-1) 和 std 不同(甚至所有 gg 都一样),注释计算 GG 的部分没有结果,不太可能是越界,所以极大概率是计算 ff 的时候有问题。
  4. 输出了进入和离开 dfs 时的 55 个标记、所有儿子顺序,均与 std 一致。

哪位大佬能帮忙看看啊 /kk

保留了部分调试痕迹。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <vector>
#include <utility>
#include <list>
#include <algorithm>
#define MAXN 1000005
using namespace std;
inline int read()
{
	int ans=0;
	char c=getchar();
	while (!isdigit(c)) c=getchar();
	while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
	return ans;
}
const int MOD=998244353;
typedef long long ll;
inline int add(const int& x,const int& y){return x+y>=MOD? x+y-MOD:x+y;}
inline int dec(const int& x,const int& y){return x<y? x-y+MOD:x-y;}
inline int qpow(int a,int p)
{
	int ans=1;
	while (p)
	{
		if (p&1) ans=(ll)ans*a%MOD;
		a=(ll)a*a%MOD,p>>=1;
	}
	return ans;
}
vector<int> T[MAXN],e[MAXN];
int fa[MAXN],son[MAXN],mx[MAXN],s[MAXN],sinv[MAXN],n,L,k;
void dfs(int u,int f)
{
	fa[u]=f,s[u]=1;
	for (int i=0;i<(int)T[u].size();i++)
		if (T[u][i]!=f)
		{
			dfs(T[u][i],u);
			if (mx[T[u][i]]>mx[son[u]]) son[u]=T[u][i];
			s[u]=(ll)s[u]*s[T[u][i]]%MOD;
		}
	mx[u]=mx[son[u]]+1;
	s[u]=add(s[u],1);
}
int fac[MAXN],finv[MAXN]; 
inline bool cmp(const int& x,const int& y){return mx[x]>mx[y];}
inline void init()
{
	fac[0]=1;
	for (int i=1;i<=n;i++)
		if (s[i]) fac[i]=(ll)fac[i-1]*s[i]%MOD;
		else fac[i]=fac[i-1];
	finv[n]=qpow(fac[n],MOD-2);
	for (int i=n-1;i>=1;i--)
		if (s[i+1]) finv[i]=(ll)finv[i+1]*s[i+1]%MOD;
		else finv[i]=finv[i+1];
	for (int i=1;i<=n;i++) if (s[i]) sinv[i]=(ll)finv[i]*fac[i-1]%MOD;
	for (int i=1;i<=n;i++) stable_sort(e[i].begin(),e[i].end(),cmp);
}
void dfs(int u)
{
	if (son[u]) dfs(son[u]);
	for (int i=0;i<(int)T[u].size();i++)
		if (T[u][i]!=fa[u]&&T[u][i]!=son[u])
			e[u].push_back(T[u][i]),dfs(T[u][i]);
}
int F1[MAXN],F2[MAXN],G1[MAXN];
struct BackDS
{
	typedef pair<int*,int> pi;
	list<pi> his;
	inline void modify(int& x,int v){his.push_back(make_pair(&x,x)),x=v;}
	inline void undo(){while (!his.empty()) *his.back().first=his.back().second,his.pop_back();}
}q[MAXN];
namespace F
{
	int buf[MAXN<<5],*cur=buf;
	int* dp[MAXN];
	inline int* newbuf(int x){int* p=cur;cur+=2*(x+5);return p;}
	int mul[MAXN],inv[MAXN],pls[MAXN],lim[MAXN],val[MAXN];
	inline int calc(int u,int i)
	{
		i=min(i,mx[u]-1);
		if (i<lim[u]) return ((ll)mul[u]*dp[u][i]+pls[u])%MOD;
		return ((ll)mul[u]*val[u]+pls[u])%MOD;
	}
	inline int clac(int u,int v){return (ll)dec(v,pls[u])*inv[u]%MOD;}
	void dfs(int u)
	{
		printf("%d in:%d %d %d %d %d\n",u,mul[u],inv[u],pls[u],lim[u],val[u]);
//		if (u==100969)
//		{
//			cerr<<"debug";
//		}
		if (son[u])
		{
			dp[son[u]]=dp[u]+1,dfs(son[u]);
			mul[u]=mul[son[u]],inv[u]=inv[son[u]],pls[u]=pls[son[u]],lim[u]=lim[son[u]]+1,val[u]=val[son[u]];
			dp[u][0]=clac(u,1);
		}
		else
		{
			mul[u]=inv[u]=1;pls[u]=2;lim[u]=n+1;
			F1[u]=calc(u,L),F2[u]=calc(u,L-1);
			return;
		}
		int las=0;
		for (int k=0;k<(int)e[u].size();k++)
		{
			int v=las=e[u][k];
			dp[v]=newbuf(mx[v]),dfs(v);
			for (int i=0;i<=mx[v];i++)
			{
				if (i==lim[u]) q[v].modify(dp[u][i],val[lim[u]]),q[v].modify(lim[u],lim[u]+1);
				q[v].modify(dp[u][i],clac(u,(ll)calc(u,i)*(i? calc(v,i-1):1)%MOD));
			}
			if (s[v])
			{
				q[v].modify(mul[u],(ll)mul[u]*s[v]%MOD);
				q[v].modify(pls[u],(ll)pls[u]*s[v]%MOD);
				q[v].modify(inv[u],(ll)inv[u]*sinv[v]%MOD);
				for (int i=0;i<=mx[v];i++) q[v].modify(dp[u][i],clac(u,(ll)sinv[v]*calc(u,i)%MOD));
			}
			else q[v].modify(lim[u],mx[v]+1),q[v].modify(val[u],clac(u,0));
		}
		if (las) q[las].modify(pls[u],add(pls[u],1));
		else pls[u]=add(pls[u],1);
		F1[u]=calc(u,L),F2[u]=calc(u,L-1);
		printf("%d out:%d %d %d %d %d\n",u,mul[u],inv[u],pls[u],lim[u],val[u]);
	}
	inline void solve(){dp[1]=newbuf(mx[1]),dfs(1);}
}
namespace G
{
	int buf[MAXN<<1],pre[MAXN],*cur=buf;
	int* dp[MAXN];
	inline int* newbuf(int x){cur+=x;int* p=cur;cur+=x;return p;}
	int mul[MAXN],inv[MAXN],pls[MAXN],lim[MAXN],val[MAXN];
	inline int calc(int u,int i)
	{
		if (i<lim[u]) return ((ll)mul[u]*dp[u][i]+pls[u])%MOD;
		return ((ll)mul[u]*val[u]+pls[u])%MOD;
	}
	inline int clac(int u,int v){return (ll)dec(v,pls[u])*inv[u]%MOD;}
	void dfs(int u)
	{
		G1[u]=calc(u,L);
		pre[0]=1;
		int pos=1,cur=1,cinv=1;
		for (int k=(int)e[u].size()-1;k>=0;k--)
		{
			int v=e[u][k];
			q[v].undo();
			dp[v]=newbuf(mx[v])-max(0,L-mx[v]+1);
			mul[v]=inv[v]=1,lim[v]=L+1;
			for (int i=max(0,L-mx[v]+1);i<=L;i++)
			{
				int t=1;
				if (i) t=(ll)t*calc(u,i-1)%MOD;
				if (i>1)
				{
					t=(ll)t*F::calc(u,i-1)%MOD;
					if (i-2<pos) t=(ll)t*pre[i-2]%MOD;
					else t=(ll)t*cur%MOD;	
				}
				dp[v][i]=clac(v,t);
			}
			pls[v]=add(pls[v],1);
			if (L-mx[v]+1<=0) dp[v][0]=clac(v,1);
			for (int i=0;i<=mx[v];i++)
			{
				if (i<pos) pre[i]=(ll)pre[i]*F::calc(v,i)%MOD;
				else pre[i]=(ll)cur*F::calc(v,i)%MOD;
			}
			pos=mx[v]+1;
			cur=(ll)cur*s[v]%MOD,cinv=(ll)cinv*sinv[v]%MOD;
		}
		int v=son[u];
		if (v)
		{
			mul[v]=mul[u],inv[v]=inv[u],pls[v]=pls[u],lim[v]=lim[u]+1,val[v]=val[u];
			dp[v]=dp[u]-1;
			for (int i=max(2,L-mx[v]+1);i<=pos+1;i++)
			{
				if (i==lim[v]) dp[v][lim[v]++]=val[v];
				dp[v][i]=clac(v,(ll)calc(v,i)*pre[i-2]%MOD);	
			} 
			if (cur)
			{
				mul[v]=(ll)mul[v]*cur%MOD;
				pls[v]=(ll)pls[v]*cur%MOD;
				inv[v]=(ll)inv[v]*cinv%MOD;
				for (int i=max(0,L-mx[v]+1);i<=pos+1;i++) dp[v][i]=clac(v,(ll)calc(v,i)*cinv%MOD);
			}
			else lim[v]=pos+1,val[v]=clac(v,0);
			pls[v]=add(pls[v],1);
			if (L-mx[v]+1<=0) dp[v][0]=clac(v,1);
			dfs(v);
		}
		for (int i=0;i<(int)e[u].size();i++) dfs(e[u][i]);
	}
	inline void solve(){dp[1]=newbuf(mx[1])-max(L-mx[1]+1,0),mul[1]=inv[1]=pls[1]=1,lim[1]=L+1,dfs(1);}
}
int main()
{
	freopen("test.in","r",stdin);
	freopen("test.out","w",stdout);
	n=read(),L=read(),k=read();
	if (!L) return printf("%d\n",n),0;
	for (int i=1;i<n;i++) 
	{
		int u,v;
		u=read(),v=read();
		T[u].push_back(v),T[v].push_back(u);
	}
	dfs(1,0),dfs(1);
	init();
	for (int i=1;i<=n;i++) if (s[i]==0) cerr<<i<<'\n';
	F::solve();
	G::solve();
	for (int i=1;i<=n;i++) printf("%d %d %d %d\n",son[i],F1[i],F2[i],G1[i]);
	int ans=0;
	for (int i=1;i<=n;i++)
	{
		ans=add(ans,qpow((ll)dec(F1[i],1)*G1[i]%MOD,k));
		if (i>1) ans=dec(ans,qpow((ll)dec(F2[i],1)*dec(G1[i],1)%MOD,k));
	}
	cout<<ans;
	return 0;
}
2020/11/23 21:03
加载中...