校内考的题,现在 TLE 90pts
拼尽全力无法调出。
#include <bits/stdc++.h>
#define int long long
using namespace std;
inline void read (int &x) {
int res = 0, f = 1;
char ch = getchar();
while (!isdigit (ch)) f = ch == '-' ? -1 : 1, ch = getchar();
while (isdigit (ch)) res = (res << 1) + (res << 3) + (ch ^ 48), ch = getchar();
x = res * f;
}
const int MAXN = 1e5 + 10, mod = 998244353, MAXK = 711;
int T, n, m, dp[MAXN][MAXK], tmp[MAXN], siz[MAXN];
unordered_map <int,int> f[MAXN], tp;
vector <int> edge[MAXN];
inline void Addedge (int u, int v) { edge[u].push_back (v); }
inline void init() { for (int i = 1; i <= n; i ++) edge[i].clear(), siz[i] = 0; }
inline void init1() { f[1].clear(); }
inline void init2() {
for (int i (1); i <= n; i ++) {
for (int j (0); j <= m + 1; j ++) dp[i][j] = 0;
}
}
void Dp1 (int u, int fa) {
dp[u][1] = 1, siz[u] = 1;
for (int v : edge[u]) {
if (v == fa) continue;
Dp1 (v, u);
for (int j (0); j <= m + 1; j ++) tmp[j] = 0;
for (int j (0); j <= min (siz[v], m); j ++) {
for (int k (min (m + 1 - j, siz[u])); k >= 1; k --)
tmp[k + j] = (tmp[k + j] + dp[u][k] * dp[v][j] % mod) % mod;
}
siz[u] += siz[v];
for (int j (0); j <= m + 1; j ++) dp[u][j] = tmp[j];
}
dp[u][0] = (dp[u][m] + dp[u][m + 1]) % mod;
}
void Dp2 (int u, int fa) {
f[u][1] = 1, siz[u] = 1;
for (int v : edge[u]) {
if (v == fa) continue;
Dp2 (v, u);
tp.clear();
for (auto [j, tmpValb] : f[v]) {
for (auto [k, tmpVala] : f[u]) {
if (j + k > m + 1) continue;
tp[j + k] = (tp[j + k] + tmpVala * tmpValb % mod) % mod;
}
}
f[u].clear(), siz[u] += siz[v];
for (auto [j, tmpVal] : tp) f[u][j] = tmpVal;
}
int ret (0);
if (f[u].count (m))
ret = (ret + f[u][m]) % mod;
if (f[u].count (m + 1))
ret = (ret + f[u][m + 1]) % mod;
f[u][0] = (ret % mod + mod) % mod;
if (!f[u][0]) f[u].erase (0);
for (int v : edge[u]) {
if (v == fa) continue;
f[v].clear();
}
}
inline void Solve() {
read (n), read (m), init();
for (int i (1), u, v; i < n; i ++) {
read (u), read (v);
Addedge (u, v), Addedge (v, u);
}
int EdgeVal ((int)sqrt (n));
if (m <= EdgeVal) {
init2();
Dp1 (1, 1);
printf ("%lld\n", dp[1][0]);
}
else {
init1();
Dp2 (1, 1);
printf ("%lld\n", f[1][0]);
}
return;
}
signed main() {
read (T);
while (T --)
Solve();
return 0;
}