Gym 104901B TLE 求调
查看原帖
Gym 104901B TLE 求调
667558
_Kamisato_Ayaka_楼主2025/1/7 20:11

Problem

校内考的题,现在 TLE 90pts

拼尽全力无法调出。

#include <bits/stdc++.h>
#define int long long

using namespace std;
inline void read (int &x) {
    int res = 0, f = 1;
    char ch = getchar();
    while (!isdigit (ch)) f = ch == '-' ? -1 : 1, ch = getchar();
    while (isdigit (ch)) res = (res << 1) + (res << 3) + (ch ^ 48), ch = getchar();
    x = res * f;
}
const int MAXN = 1e5 + 10, mod = 998244353, MAXK = 711;
int T, n, m, dp[MAXN][MAXK], tmp[MAXN], siz[MAXN];
unordered_map <int,int> f[MAXN], tp;
vector <int> edge[MAXN];
inline void Addedge (int u, int v) { edge[u].push_back (v); }

inline void init() { for (int i = 1; i <= n; i ++) edge[i].clear(), siz[i] = 0; }
inline void init1() { f[1].clear(); }
inline void init2() {
    for (int i (1); i <= n; i ++) {
        for (int j (0); j <= m + 1; j ++) dp[i][j] = 0;
    }
}

void Dp1 (int u, int fa) {
    dp[u][1] = 1, siz[u] = 1;
    for (int v : edge[u]) {
        if (v == fa) continue;
        Dp1 (v, u);
        for (int j (0); j <= m + 1; j ++) tmp[j] = 0;
        for (int j (0); j <= min (siz[v], m); j ++) {
            for (int k (min (m + 1 - j, siz[u])); k >= 1; k --)
                tmp[k + j] = (tmp[k + j] + dp[u][k] * dp[v][j] % mod) % mod;
        }
        siz[u] += siz[v];
        for (int j (0); j <= m + 1; j ++) dp[u][j] = tmp[j];
    }
    dp[u][0] = (dp[u][m] + dp[u][m + 1]) % mod;
}

void Dp2 (int u, int fa) {
    f[u][1] = 1, siz[u] = 1;
    for (int v : edge[u]) {
        if (v == fa) continue;
        Dp2 (v, u);
        tp.clear();
        for (auto [j, tmpValb] : f[v]) {
            for (auto [k, tmpVala] : f[u]) {
                if (j + k > m + 1) continue;
                tp[j + k] = (tp[j + k] + tmpVala * tmpValb % mod) % mod;
            }
        }
        f[u].clear(), siz[u] += siz[v];
        for (auto [j, tmpVal] : tp) f[u][j] = tmpVal;
    }
    int ret (0);
    if (f[u].count (m))
        ret = (ret + f[u][m]) % mod;
    if (f[u].count (m + 1))
        ret = (ret + f[u][m + 1]) % mod;
    f[u][0] = (ret % mod + mod) % mod;
    if (!f[u][0]) f[u].erase (0);
    for (int v : edge[u]) { 
        if (v == fa) continue;
        f[v].clear();
    }
}

inline void Solve() {
    read (n), read (m), init();
    for (int i (1), u, v; i < n; i ++) {
        read (u), read (v);
        Addedge (u, v), Addedge (v, u);
    }
    int EdgeVal ((int)sqrt (n));
    if (m <= EdgeVal) {
        init2();
        Dp1 (1, 1);
        printf ("%lld\n", dp[1][0]);
    }
    else {
        init1();
        Dp2 (1, 1);
        printf ("%lld\n", f[1][0]);
    }
    return;
}

signed main() {
    read (T);
    while (T --)
        Solve();
    return 0;
}
2025/1/7 20:11
加载中...