树上差分模板求查错
查看原帖
树上差分模板求查错
56856
youngAnswer楼主2021/11/8 20:13

如题（rt）：下面这份树上差分模板代码（统计每个点被多少条路径覆盖并输出最大值）结果不对，求帮忙查错。

#include <bits/stdc++.h>
#define pb push_back
using namespace std;
// Capacity for per-node arrays (node ids are used directly as indices 1..n).
// NOTE(review): presumably the problem guarantees n <= 5e4 — confirm against the statement.
const int N = 5e4 + 5;
// n: number of nodes; k: number of paths; a, b: scratch endpoints read by main.
// x, y, vis, mp: not used by the visible code (x/y are shadowed by function parameters).
// lg[i] = floor(log2(i)). main fills it for every i in [0, n], and lca indexes it by node
// depth (up to n - 1). It therefore needs N entries. The previous lg[30] was written far
// out of bounds for any n >= 30, which is undefined behaviour and can silently overwrite
// neighbouring globals such as cnt / fa (the latter is already filled by dfs at that point).
// ans: maximum cnt value over all nodes after dfs2.
// cnt[u]: difference counter, turned into "number of paths covering u" by dfs2.
// fa[u][i]: 2^i-th ancestor of u (0 past the root); dep[u]: depth of u (root = 0).
int n, a, b, k, x, y, lg[N], ans, cnt[N], fa[N][30], dep[N], vis[N];
std::map<std::pair<int, int>, int> mp;  // unused in the code shown
// Adjacency lists of the undirected tree.
std::vector<int> v[N];

// Roots the tree below `now`. Every neighbour other than now's parent
// (fa[now][0]) becomes a child: its parent and depth are recorded, then it is explored
// recursively. Call as dfs(1) with fa[1][0] == 0 and dep[1] == 0.
// NOTE(review): recursion depth equals tree height (up to n on a path graph).
inline void dfs(int now){
	for (int child : v[now]){
		if (child == fa[now][0]) continue;  // don't walk back up to the parent
		fa[child][0] = now;
		dep[child] = dep[now] + 1;
		dfs(child);
	}
}

// Post-order prefix sum of the tree difference: after visiting x, cnt[x] equals
// its own difference value plus the (already summed) values of all children.
// That is the number of query paths passing through x. `last` is x's parent
// (0 for the root). `ans` is raised to the largest cnt seen.
inline void dfs2(int x, int last){
	for (int child : v[x]){
		if (child == last) continue;
		dfs2(child, x);
		cnt[x] += cnt[child];
	}
	if (cnt[x] > ans) ans = cnt[x];
}

// Lowest common ancestor of x and y via binary lifting.
// Requires: fa[u][i] holds the 2^i-th ancestor (0 above the root), dep[] is filled,
// and lg[d] = floor(log2(d)) is valid for every depth d that can occur.
inline int lca(int x, int y){
	// Make x the deeper endpoint.
	if (dep[x] < dep[y]) swap(x, y);

	// Lift x up to y's depth: take a 2^bit jump for every set bit of the depth gap.
	// Processing bits high-to-low, each jump clears exactly that bit of the gap.
	for (int bit = lg[dep[x] - dep[y]]; bit >= 0; --bit){
		const int gap = dep[x] - dep[y];
		if ((gap >> bit) & 1) x = fa[x][bit];
	}

	if (x == y) return x;  // y was an ancestor of the original x

	// Raise both nodes in lockstep while their 2^bit-th ancestors differ; afterwards
	// x and y are distinct children of the LCA.
	for (int bit = lg[dep[x]]; bit >= 0; --bit){
		const int upX = fa[x][bit];
		const int upY = fa[y][bit];
		if (upX != upY){
			x = upX;
			y = upY;
		}
	}

	return fa[x][0];
}
// Reads a tree of n nodes and k paths. Each path adds 1 to every node on it, using
// the vertex form of the tree difference. Prints the largest resulting node count.
int main(){
	scanf("%d%d", &n, &k);

	// n - 1 undirected edges.
	for (int edge = 0; edge + 1 < n; ++edge){
		scanf("%d%d", &a, &b);
		v[a].push_back(b);
		v[b].push_back(a);
	}

	// Root at node 1: fills fa[*][0] and dep[*].
	dfs(1);

	// lg[i] = floor(log2(i)); lg[0] stays 0. NOTE: writes lg[1..n], so lg needs n + 1 slots.
	for (int i = 1; i <= n; ++i){
		const int prev = lg[i - 1];
		lg[i] = ((1 << (prev + 1)) == i) ? prev + 1 : prev;
	}

	// Binary-lifting table, level by level: the 2^lvl-th ancestor is the
	// 2^(lvl-1)-th ancestor of the 2^(lvl-1)-th ancestor.
	for (int lvl = 1; lvl <= lg[n]; ++lvl)
		for (int node = 1; node <= n; ++node)
			fa[node][lvl] = fa[fa[node][lvl - 1]][lvl - 1];

	// Vertex difference for path a-b: +1 at both ends, -1 at the LCA and -1 at the
	// LCA's parent, so the subtree sums in dfs2 count the LCA exactly once.
	for (int query = 0; query < k; ++query){
		scanf("%d%d", &a, &b);
		++cnt[a];
		++cnt[b];
		const int top = lca(a, b);
		--cnt[top];
		--cnt[fa[top][0]];  // fa[root][0] == 0, so index 0 absorbs the root case
	}

	dfs2(1, 0);
	printf("%d\n", ans);
	return 0;
}
2021/11/8 20:13
加载中...