// Code
#include<bits/stdc++.h>
using namespace std;
#define ll long long
ll n;                  // number of tree nodes
ll m;                  // number of queries
ll mod=998244353;      // modulus applied to every power and prefix sum
ll ans=0;              // NOTE(review): not referenced anywhere in the code shown
ll h[300010];          // h[v]: depth of node v (main() arranges for node 1 to have depth 0)
ll fa[300010][30];     // fa[v][i]: the 2^i-th ancestor of v (binary-lifting table)
ll pre[300010][55];    // pre[v][k]: sum of depth^k over the path v..root, mod `mod` (k = 1..50 filled)
vector<ll> a[300010];  // a[v]: neighbours of v (undirected adjacency list)
// Reads one signed decimal integer from stdin using getchar().
// Any non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative. Reading stops at the first non-digit
// after the number (that character is consumed).
// Returns 0 if end of input is reached before any digit appears. The old
// version kept `c` in a char, so EOF never matched and the skip loop spun
// forever on truncated input.
inline ll read()
{
    ll x = 0, f = 1;
    int c = getchar();  // int, not char: EOF must stay distinguishable
    while (c != EOF && (c < '0' || c > '9'))
    {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}
// Modular exponentiation: returns a^b mod `mod` by square-and-multiply.
// The base is first reduced into [0, mod). This keeps every product below at
// most (mod-1)^2, which fits in a long long; unreduced bases of roughly 3e9
// or more would overflow. It also makes negative bases yield a non-negative
// result.
// b is expected to be >= 0. The `b > 0` loop condition makes a negative b
// return 1 rather than loop forever (right-shifting a negative b typically
// leaves it at -1).
ll get(ll a,ll b)
{
    a %= mod;
    if (a < 0) a += mod;
    ll ret = 1;
    while (b > 0)
    {
        if (b & 1) ret = ret * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ret % mod;  // also covers b == 0 with mod == 1
}
// Traverses the tree from `x`, whose parent is taken to be `father`, never
// stepping back into the parent. For every node v reached (with parent p) it fills:
//   h[v]      = h[p] + 1   (the caller seeds h[] for the start node's parent;
//                           main() sets h[1] = -1 and calls dfs(1, 1), so the
//                           root gets depth 0),
//   fa[v][0]  = p,
//   pre[v][i] = (pre[p][i] + h[v]^i) mod `mod`, for i = 1..50.
// An explicit stack replaces recursion. A path-shaped tree with ~3e5 nodes
// would otherwise recurse ~3e5 levels deep and can overflow a default
// (e.g. 8 MB) call stack. Each node is still processed after its parent, so
// the resulting arrays match the recursive version.
void dfs(ll x,ll father)
{
    std::vector<std::pair<ll, ll> > stk;  // pending (node, parent) pairs
    stk.push_back(std::make_pair(x, father));
    while (!stk.empty())
    {
        ll u = stk.back().first;
        ll p = stk.back().second;
        stk.pop_back();

        h[u] = h[p] + 1;
        fa[u][0] = p;

        // Build h[u]^i incrementally instead of a fresh fast power per i.
        ll pw = 1;
        for (int i = 1; i <= 50; i++)
        {
            pw = pw * h[u] % mod;
            pre[u][i] = (pre[p][i] + pw) % mod;
        }

        for (std::size_t j = 0; j < a[u].size(); j++)
            if (a[u][j] != p)
                stk.push_back(std::make_pair(a[u][j], u));
    }
}
// Lowest common ancestor of x and y by binary lifting.
// Reads h[] (depths) and fa[v][i], the 2^i-th ancestor of v as built in
// main() from fa[v][0].
ll lca(ll x,ll y)
{
    // Make x the deeper of the two endpoints.
    if (h[x] < h[y]) swap(x, y);

    // Lift x by exactly the depth gap, taking one jump per set bit of the gap.
    ll gap = h[x] - h[y];
    for (ll bit = 0; gap > 0; bit++, gap >>= 1)
        if (gap & 1)
            x = fa[x][bit];

    if (x == y) return x;

    // Climb both nodes in lockstep whenever their ancestors still differ;
    // afterwards they are distinct children of the LCA.
    for (ll lvl = log2(h[x]); lvl >= 0; lvl--)
    {
        if (fa[x][lvl] == fa[y][lvl]) continue;
        x = fa[x][lvl];
        y = fa[y][lvl];
    }
    return fa[x][0];
}
int main()
{
n=read();
for (ll i=1;i<n;i++)
{
ll x=read(),y=read();
a[x].push_back(y);
a[y].push_back(x);
}
h[1]=-1;
dfs(1,1);
for (ll i=1;(i<<i)<=n;i++)
for (ll j=1;j<=n;j++)
fa[j][i]=fa[fa[j][i-1]][i-1];
m=read();
for (int i=1;i<=m;i++)
{
ll x=read(),y=read(),k=read();
ll root=lca(x,y);
int sum1=(pre[x][k]+pre[y][k]+mod)%mod;
int sum2=(pre[root][k]+pre[fa[root][0]][k]+mod)%mod;
printf("%lld\n",((sum1-sum2)+mod)%mod);
}
return 0;
}