// 555 (sobbing) — this solution only scores 6 points; asking for help.
#include <bits/stdc++.h>
using namespace std;
const int N = 50005;     // capacity of every per-node array; node ids must be < N
int n;                   // number of nodes in the tree (read in main)
int k;                   // number of (S, T) pairs to read (read in main)
int d[N];                // d[x]: depth of node x, the root (node 1) is set to depth 1 in main
int f[N][25];            // f[x][j]: 2^j-th ancestor of x (binary-lifting table, j used up to 20)
int sum[N];              // per-node difference counters, turned into subtree sums by dfs
int ans;                 // largest value of sum[x] seen by dfs
std::vector<int> E[N];   // E[x]: neighbours of node x (undirected adjacency lists)
// Roots the subtree below `x` (whose parent is `fa`): every neighbour other than
// `fa` is treated as a child, gets depth d[x] + 1 and direct parent x, and is
// then processed recursively.
void init(int x, int fa) {
    for (int to : E[x]) {
        if (to != fa) {
            f[to][0] = x;          // level-0 entry of the binary-lifting table
            d[to] = d[x] + 1;
            init(to, x);
        }
    }
}
// Returns the lowest common ancestor of u and v using the depth array d[] and
// the binary-lifting table f[][] (levels 0..20).
int query(int u, int v) {
    // Let v be the deeper node, then lift it until both are at the same depth.
    if (d[u] > d[v]) swap(u, v);
    int gap = d[v] - d[u];
    for (int bit = 20; bit >= 0 && gap != 0; --bit) {
        const int step = 1 << bit;
        if (gap >= step) {
            gap -= step;
            v = f[v][bit];
        }
    }
    if (u == v) return u;
    // Lift both nodes together as long as their ancestors still differ; they end
    // up as the two children of the LCA on the respective paths.
    for (int bit = 20; bit >= 0; --bit) {
        if (f[u][bit] == f[v][bit]) continue;
        u = f[u][bit];
        v = f[v][bit];
    }
    return f[u][0];
}
// Post-order pass: adds every child's (already accumulated) counter into sum[x],
// so sum[x] ends up as the total of the initial counters over x's subtree, and
// records the largest such total in `ans`.
void dfs(int x, int fa) {
    for (int child : E[x]) {
        if (child == fa) continue;
        dfs(child, x);
        sum[x] += sum[child];
    }
    if (sum[x] > ans) ans = sum[x];
}
// Reads a tree with n nodes (n - 1 edges) and k pairs (S, T), then prints the
// largest number of S-T paths that pass through any single node.
//
// Each path is added with the node-weighted tree difference:
//   +1 at S, +1 at T, -1 at p = lca(S, T), -1 at p's parent.
// The post-order subtree sums computed by dfs then give, for every node, the
// number of paths covering it; dfs keeps the maximum in `ans`.
int main() {
    ios::sync_with_stdio(false);  // only cin is used for input; output goes through printf
    cin.tie(nullptr);
    cin >> n >> k;
    for (int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        E[u].push_back(v);
        E[v].push_back(u);
    }
    // Root the tree at node 1. Its parent is the sentinel 0, which is not a real
    // node. The old f[1][0] = 1 made the root its own parent, so when the LCA was
    // the root, the "-1 at p's parent" update hit the root twice and erased that
    // path's count there.
    d[1] = 1;
    f[1][0] = 0;
    init(1, -1);
    // Binary lifting: the 2^j-th ancestor of i is the 2^(j-1)-th ancestor of its
    // 2^(j-1)-th ancestor. (The old code read row i - 1 here, corrupting the whole
    // table.) Jumps past the root land on 0 and stay there, since f[0][*] == 0.
    for (int j = 1; j <= 20; ++j)
        for (int i = 1; i <= n; ++i)
            f[i][j] = f[f[i][j - 1]][j - 1];
    while (k--) {
        int S, T;
        cin >> S >> T;
        const int p = query(S, T);
        ++sum[S];
        ++sum[T];
        --sum[p];
        --sum[f[p][0]];  // for p == 1 this touches only slot 0, which dfs never visits
    }
    dfs(1, -1);
    printf("%d", ans);
    return 0;
}