主要参考了这篇博客,写法比较类似。下文中的 std 指这篇博客中的代码。
目前进展:
哪位大佬能帮忙看看啊 /kk
保留了部分调试痕迹。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <vector>
#include <utility>
#include <list>
#include <algorithm>
#define MAXN 1000005
using namespace std;
// Reads the next non-negative decimal integer from stdin.
// Skips every character that is not a digit, then accumulates consecutive
// digits. Returns 0 if end-of-file is reached before any digit is seen.
// No sign handling and no overflow checking are performed.
inline int read()
{
    // Keep the result of getchar() as int: it is either EOF or a value in the
    // unsigned-char range, which is exactly the domain isdigit() is defined on.
    int ch = getchar();
    // Stop at EOF as well, otherwise a truncated input would spin forever.
    while (ch != EOF && !isdigit(ch)) ch = getchar();
    int value = 0;
    while (ch != EOF && isdigit(ch))
    {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value;
}
// Prime modulus used by all arithmetic in this file.
const int MOD=998244353;
typedef long long ll;

// Returns x + y reduced into [0, MOD), assuming both operands are already in [0, MOD).
inline int add(const int& x,const int& y)
{
    const int sum = x + y;
    return sum >= MOD ? sum - MOD : sum;
}

// Returns x - y reduced into [0, MOD), assuming both operands are already in [0, MOD).
inline int dec(const int& x,const int& y)
{
    const int diff = x - y;
    return diff < 0 ? diff + MOD : diff;
}

// Binary exponentiation: returns a^p modulo MOD for p >= 0.
inline int qpow(int a,int p)
{
    int result = 1;
    for (; p; p >>= 1)
    {
        if (p & 1) result = (ll)result * a % MOD;
        a = (ll)a * a % MOD;
    }
    return result;
}
// T[u]: adjacency list of the input tree (filled in main()).
// e[u]: children of u other than son[u]; filled by dfs(int) and sorted by
//       decreasing mx in init().
vector<int> T[MAXN],e[MAXN];
// fa[u]:   parent of u in the tree rooted at 1 (0 for the root).
// son[u]:  child of u with the largest mx, or 0 if u has no child.
// mx[u]:   mx[son[u]] + 1, i.e. 1 for a leaf.
// s[u]:    (product of s[child] over all children) + 1, modulo MOD.
// sinv[u]: modular inverse of s[u] when s[u] != 0 (computed in init()), else left 0.
// n, L, k: the three integers read at the start of the input.
int fa[MAXN],son[MAXN],mx[MAXN],s[MAXN],sinv[MAXN],n,L,k;
// First pass over the tree rooted at the initial call.
// For every vertex u in the subtree it records the parent fa[u], picks
// son[u] as the child with the largest mx (first such child on ties),
// sets mx[u] = mx[son[u]] + 1 and s[u] = (product of children's s) + 1 mod MOD.
//   u: current vertex, f: the vertex we arrived from (0 for the root).
void dfs(int u,int f)
{
    fa[u] = f;
    s[u] = 1;
    const int degree = (int)T[u].size();
    for (int idx = 0; idx < degree; ++idx)
    {
        const int child = T[u][idx];
        if (child == f) continue;          // do not walk back to the parent
        dfs(child, u);
        if (mx[son[u]] < mx[child]) son[u] = child;
        s[u] = (ll)s[u] * s[child] % MOD;
    }
    mx[u] = 1 + mx[son[u]];
    s[u] = add(s[u], 1);
}
// fac[i]:  product of the non-zero values among s[1..i], modulo MOD (fac[0] = 1).
// finv[i]: modular inverse of fac[i], for 1 <= i <= n. Both are filled by init().
int fac[MAXN],finv[MAXN];
// Ordering predicate on vertices: x comes before y when mx[x] is strictly larger.
inline bool cmp(const int& x,const int& y)
{
    return mx[y] < mx[x];
}
// Precomputes sinv[] with a single modular exponentiation and orders every
// e[i] by decreasing mx.
// Zero values of s[] are skipped in the prefix products so that the products
// stay invertible; sinv[i] is left untouched for those vertices.
inline void init()
{
    // Prefix products of the non-zero s values.
    fac[0] = 1;
    for (int i = 1; i <= n; ++i)
    {
        if (s[i]) fac[i] = (ll)fac[i-1] * s[i] % MOD;
        else fac[i] = fac[i-1];
    }

    // Inverse of the full product, then peel one factor off at a time.
    finv[n] = qpow(fac[n], MOD - 2);
    for (int j = n; j >= 2; --j)
    {
        if (s[j]) finv[j-1] = (ll)finv[j] * s[j] % MOD;
        else finv[j-1] = finv[j];
    }

    // 1 / s[i] = (1 / fac[i]) * fac[i-1].
    for (int i = 1; i <= n; ++i)
    {
        if (!s[i]) continue;
        sinv[i] = (ll)finv[i] * fac[i-1] % MOD;
    }

    for (int i = 1; i <= n; ++i)
        stable_sort(e[i].begin(), e[i].end(), cmp);
}
// Second pass: visits son[u] first, then every other child, appending each
// non-son child of u to e[u] in adjacency order before recursing into it.
void dfs(int u)
{
    if (son[u]) dfs(son[u]);
    for (int idx = 0; idx < (int)T[u].size(); ++idx)
    {
        const int child = T[u][idx];
        if (child == fa[u] || child == son[u]) continue;
        e[u].push_back(child);
        dfs(child);
    }
}
// Per-vertex results consumed by main():
//   F1[u] = F::calc(u, L) and F2[u] = F::calc(u, L-1), stored by F::dfs;
//   G1[u] = G::calc(u, L), stored by G::dfs.
int F1[MAXN],F2[MAXN],G1[MAXN];
// Undo log for plain int variables.
// modify() overwrites a variable and remembers its previous value; undo()
// restores every recorded variable in reverse order of modification and
// leaves the log empty. One log q[v] exists per vertex v.
struct BackDS
{
    typedef std::pair<int*,int> pi;   // (address of the variable, value it held before the write)
    // The log is only ever pushed to and popped from the back, so a vector
    // is used instead of a linked list: same LIFO behaviour, no per-entry
    // node allocation.
    std::vector<pi> his;

    // Records the current value of x, then assigns v to it.
    inline void modify(int& x,int v)
    {
        his.push_back(std::make_pair(&x,x));
        x=v;
    }

    // Rolls back all recorded writes, newest first.
    inline void undo()
    {
        while (!his.empty())
        {
            *his.back().first=his.back().second;
            his.pop_back();
        }
        // Give the storage back, as the log is not expected to be reused.
        std::vector<pi>().swap(his);
    }
}q[MAXN];
// Bottom-up pass.
// Every vertex u owns a lazily-tagged array: the logical value at index i is
//     mul[u] * dp[u][i] + pls[u]   if i <  lim[u]
//     mul[u] * val[u]   + pls[u]   if i >= lim[u]
// (all modulo MOD), and inv[u] is kept equal to the modular inverse of mul[u].
// dp[son[u]] is dp[u] + 1, so a vertex shares storage with its son chain.
// Every write that merges a non-son child v into u goes through q[v], so that
// G::dfs can later roll it back with q[v].undo().
namespace F
{
    int buf[MAXN<<5],*cur=buf;   // bump-allocated backing storage for dp[]
    int* dp[MAXN];

    // Hands out the next 2*(x+5) ints of buf and returns a pointer to their start.
    inline int* newbuf(int x)
    {
        int* p=cur;
        cur+=2*(x+5);
        return p;
    }

    int mul[MAXN],inv[MAXN],pls[MAXN],lim[MAXN],val[MAXN];

    // Logical value of u at index i; i is clamped to at most mx[u]-1.
    inline int calc(int u,int i)
    {
        i=min(i,mx[u]-1);
        if (i<lim[u]) return ((ll)mul[u]*dp[u][i]+pls[u])%MOD;
        return ((ll)mul[u]*val[u]+pls[u])%MOD;
    }

    // Inverse of calc: the raw number to store so that the logical value becomes v.
    inline int clac(int u,int v){return (ll)dec(v,pls[u])*inv[u]%MOD;}

    // Computes the tagged array of u (dp[u] must already point at its storage)
    // and stores F1[u] = calc(u, L), F2[u] = calc(u, L-1).
    void dfs(int u)
    {
        // NOTE(review): debugging trace written to stdout; remove before submitting.
        printf("%d in:%d %d %d %d %d\n",u,mul[u],inv[u],pls[u],lim[u],val[u]);
        // if (u==100969)
        // {
        // cerr<<"debug";
        // }
        if (son[u])
        {
            // Inherit the son's array shifted by one index, together with its tags.
            dp[son[u]]=dp[u]+1,dfs(son[u]);
            mul[u]=mul[son[u]],inv[u]=inv[son[u]],pls[u]=pls[son[u]],lim[u]=lim[son[u]]+1,val[u]=val[son[u]];
            dp[u][0]=clac(u,1);   // logical value 1 at index 0 before any merge
        }
        else
        {
            // Leaf: identity scale, every index evaluates to 2.
            mul[u]=inv[u]=1;pls[u]=2;lim[u]=n+1;
            F1[u]=calc(u,L),F2[u]=calc(u,L-1);
            return;
        }
        int las=0;   // last merged non-son child, 0 if there is none
        for (int idx=0;idx<(int)e[u].size();idx++)
        {
            int v=las=e[u][idx];
            dp[v]=newbuf(mx[v]),dfs(v);
            // Indices 0..mx[v] are multiplied one by one by v's value at index i-1.
            for (int i=0;i<=mx[v];i++)
            {
                // Index i is the first one covered by the constant tail:
                // materialise it with this vertex's own tail value val[u]
                // (val is indexed by vertex, not by position) and move the
                // boundary one step to the right.
                if (i==lim[u]) q[v].modify(dp[u][i],val[u]),q[v].modify(lim[u],lim[u]+1);
                q[v].modify(dp[u][i],clac(u,(ll)calc(u,i)*(i? calc(v,i-1):1)%MOD));
            }
            if (s[v])
            {
                // Multiply the whole array by s[v] through the tags, then divide
                // the indices handled above back by s[v] so only the rest changes.
                q[v].modify(mul[u],(ll)mul[u]*s[v]%MOD);
                q[v].modify(pls[u],(ll)pls[u]*s[v]%MOD);
                q[v].modify(inv[u],(ll)inv[u]*sinv[v]%MOD);
                for (int i=0;i<=mx[v];i++) q[v].modify(dp[u][i],clac(u,(ll)sinv[v]*calc(u,i)%MOD));
            }
            // s[v] is 0 modulo MOD and has no inverse: make every index beyond
            // mx[v] evaluate to 0 through the constant tail instead.
            else q[v].modify(lim[u],mx[v]+1),q[v].modify(val[u],clac(u,0));
        }
        // Final "+1" on every index; logged in the last child's undo log when one exists.
        if (las) q[las].modify(pls[u],add(pls[u],1));
        else pls[u]=add(pls[u],1);
        F1[u]=calc(u,L),F2[u]=calc(u,L-1);
        // NOTE(review): debugging trace written to stdout; remove before submitting.
        printf("%d out:%d %d %d %d %d\n",u,mul[u],inv[u],pls[u],lim[u],val[u]);
    }

    // Runs the pass from the root (vertex 1).
    inline void solve(){dp[1]=newbuf(mx[1]),dfs(1);}
}
// Top-down pass, run after F::solve().
// Uses the same lazy representation as namespace F: the logical value of
// vertex u at index i is mul[u]*dp[u][i]+pls[u] for i < lim[u] and
// mul[u]*val[u]+pls[u] otherwise (modulo MOD); inv[u] is updated alongside
// mul[u] with the inverse factor.
namespace G
{
// buf: backing storage for dp[]; pre: scratch array reused by every dfs() call.
int buf[MAXN<<1],pre[MAXN],*cur=buf;
int* dp[MAXN];
// Reserves 2*x ints and returns a pointer to the middle of them, so the
// result has x ints of room on either side.
inline int* newbuf(int x){cur+=x;int* p=cur;cur+=x;return p;}
int mul[MAXN],inv[MAXN],pls[MAXN],lim[MAXN],val[MAXN];
// Logical value of u at index i (no clamping of i, unlike F::calc).
inline int calc(int u,int i)
{
if (i<lim[u]) return ((ll)mul[u]*dp[u][i]+pls[u])%MOD;
return ((ll)mul[u]*val[u]+pls[u])%MOD;
}
// Inverse of calc: raw number to store so that the logical value becomes v.
inline int clac(int u,int v){return (ll)dec(v,pls[u])*inv[u]%MOD;}
// Stores G1[u], then builds the arrays of u's children from u's array and
// recurses into them. Relies on q[v] still holding the writes F::dfs logged
// while merging v into u.
void dfs(int u)
{
G1[u]=calc(u,L);
pre[0]=1;
// pre[0..pos-1] holds the running product of F::calc(v, i) over the children
// processed so far; cur / cinv hold the running product of s[v] / sinv[v].
// Note: this local `cur` shadows the buffer pointer G::cur inside dfs().
int pos=1,cur=1,cinv=1;
// Children in e[u] are visited back to front, i.e. in the reverse of the
// order in which F::dfs merged them.
for (int k=(int)e[u].size()-1;k>=0;k--)
{
int v=e[u][k];
// Roll F's state of u back to what it was before v was merged.
q[v].undo();
// Shift the pointer so that indices max(0,L-mx[v]+1)..L map onto the start
// of the freshly reserved storage.
dp[v]=newbuf(mx[v])-max(0,L-mx[v]+1);
mul[v]=inv[v]=1,lim[v]=L+1;
for (int i=max(0,L-mx[v]+1);i<=L;i++)
{
// t = calc(u,i-1) * F::calc(u,i-1) * (pre[i-2], or cur once i-2 >= pos),
// with the factors that do not apply for i == 0 or i <= 1 left out.
int t=1;
if (i) t=(ll)t*calc(u,i-1)%MOD;
if (i>1)
{
t=(ll)t*F::calc(u,i-1)%MOD;
if (i-2<pos) t=(ll)t*pre[i-2]%MOD;
else t=(ll)t*cur%MOD;
}
dp[v][i]=clac(v,t);
}
// "+1" on every index, then force the logical value at index 0 to 1 when
// index 0 is part of v's range.
pls[v]=add(pls[v],1);
if (L-mx[v]+1<=0) dp[v][0]=clac(v,1);
// Fold v's F values into the running products.
for (int i=0;i<=mx[v];i++)
{
if (i<pos) pre[i]=(ll)pre[i]*F::calc(v,i)%MOD;
else pre[i]=(ll)cur*F::calc(v,i)%MOD;
}
pos=mx[v]+1;
cur=(ll)cur*s[v]%MOD,cinv=(ll)cinv*sinv[v]%MOD;
}
int v=son[u];
if (v)
{
// The son reuses u's array shifted by one index and starts from u's tags.
mul[v]=mul[u],inv[v]=inv[u],pls[v]=pls[u],lim[v]=lim[u]+1,val[v]=val[u];
dp[v]=dp[u]-1;
// NOTE(review): the upper bound pos+1 is not capped at L, and the
// materialisation below only triggers if the loop reaches i == lim[v]
// exactly; confirm neither case can write outside the intended range.
for (int i=max(2,L-mx[v]+1);i<=pos+1;i++)
{
// Materialise the first index of the constant tail before touching it.
if (i==lim[v]) dp[v][lim[v]++]=val[v];
dp[v][i]=clac(v,(ll)calc(v,i)*pre[i-2]%MOD);
}
if (cur)
{
// Scale the whole array by cur through the tags, then divide the indices
// handled explicitly back by cur.
mul[v]=(ll)mul[v]*cur%MOD;
pls[v]=(ll)pls[v]*cur%MOD;
inv[v]=(ll)inv[v]*cinv%MOD;
for (int i=max(0,L-mx[v]+1);i<=pos+1;i++) dp[v][i]=clac(v,(ll)calc(v,i)*cinv%MOD);
}
// cur is 0 modulo MOD: make every index from pos+1 on evaluate to 0 via the tail.
else lim[v]=pos+1,val[v]=clac(v,0);
pls[v]=add(pls[v],1);
if (L-mx[v]+1<=0) dp[v][0]=clac(v,1);
dfs(v);
}
for (int i=0;i<(int)e[u].size();i++) dfs(e[u][i]);
}
// Root setup: every index in range evaluates to 1 (mul = inv = pls = 1 over
// zero-initialised storage), then run the pass from vertex 1.
inline void solve(){dp[1]=newbuf(mx[1])-max(L-mx[1]+1,0),mul[1]=inv[1]=pls[1]=1,lim[1]=L+1,dfs(1);}
}
// Entry point. Reads n, L, k and the n-1 tree edges from "test.in", runs both
// passes and writes the per-vertex debugging table followed by the answer to
// "test.out". Vertices whose s value is 0 modulo MOD are reported on stderr.
int main()
{
    freopen("test.in","r",stdin);
    freopen("test.out","w",stdout);

    n = read();
    L = read();
    k = read();
    if (L == 0)
    {
        // Special case handled up front: the answer is n.
        printf("%d\n",n);
        return 0;
    }

    for (int edge = 1; edge < n; ++edge)
    {
        const int a = read();
        const int b = read();
        T[a].push_back(b);
        T[b].push_back(a);
    }

    dfs(1,0);
    dfs(1);
    init();

    for (int node = 1; node <= n; ++node)
    {
        if (!s[node]) cerr<<node<<'\n';
    }

    F::solve();
    G::solve();

    // Debugging table: one line per vertex.
    for (int node = 1; node <= n; ++node)
        printf("%d %d %d %d\n",son[node],F1[node],F2[node],G1[node]);

    // Add one term per vertex and subtract one term per non-root vertex.
    int ans = 0;
    for (int node = 1; node <= n; ++node)
    {
        const int perVertex = qpow((ll)dec(F1[node],1)*G1[node]%MOD,k);
        ans = add(ans,perVertex);
        if (node == 1) continue;
        const int perEdge = qpow((ll)dec(F2[node],1)*dec(G1[node],1)%MOD,k);
        ans = dec(ans,perEdge);
    }
    cout<<ans;
    return 0;
}