树链剖分求LCA 50pts求助
查看原帖
树链剖分求LCA 50pts求助
846303
Stdinout（楼主） 2024/10/19 07:44
#include <stdio.h>
#include <algorithm>
#include <string.h>
/// Heavy-light decomposition of a weighted tree (nodes 1-based), used for
/// LCA queries and root distances. Edges live in a linked adjacency list:
/// head[x] is the first edge out of x, e[i].ne the next one (0 ends the list).
namespace LCA {
	// One directed half of an undirected edge: target, next edge index, weight.
	struct edge {
		int to, ne, val;
	} e[600005];
	// n, m, s: not used inside this namespace.
	// ecnt: number of stored half-edges.
	// dep: depth (root = 1); siz: subtree size; son: heavy child (0 = none);
	// top: first node of the heavy chain containing the node; f: parent (0 for root).
	int n, m, s, ecnt, head[300005], dep[300005], siz[300005], son[300005], top[300005], f[300005];
	// dist[x]: sum of edge weights from the root to x.
	// val[x]: weight of the edge between x and its parent f[x] (0 for the root).
	int dist[300005], val[300005];

	/// Append a directed edge x -> y with weight `val`.
	void add(int x, int y, int val) {
		e[++ecnt].to = y;
		e[ecnt].ne = head[x];
		e[ecnt].val = val;
		head[x] = ecnt;
	}

	/// First HLD pass: parents, depths, root distances, parent-edge weights,
	/// subtree sizes and heavy children for the subtree rooted at x.
	/// @param x  current node
	/// @param fa parent of x (-1 when x is the root)
	void dfs1(int x, int fa) {
		siz[x] = 1;
		dep[x] = dep[f[x]] + 1;
		for (int i = head[x]; i; i = e[i].ne) {
			int y = e[i].to;
			if (y == fa) continue;
			if (y == f[x]) continue;
			f[y] = x;
			dist[y] = dist[x] + e[i].val;
			// Store the edge *weight*, not the endpoint id (e[i].to).
			val[y] = e[i].val;
			dfs1(y, x);
			siz[x] += siz[y];
			if (!son[x] || siz[son[x]] < siz[y])
				son[x] = y;
		}
	}

	/// Second HLD pass: assign chain heads. The heavy child continues the
	/// current chain; every light child starts a new chain.
	/// @param x    current node
	/// @param topn head of the chain x belongs to
	/// @param fa   parent of x (-1 when x is the root)
	void dfs2(int x, int topn, int fa) {
		top[x] = topn;
		if (son[x]) dfs2(son[x], topn, x);
		for (int i = head[x]; i; i = e[i].ne) {
			int y = e[i].to;
			if (y == fa) continue;
			if (y == f[x] || y == son[x]) continue;
			dfs2(y, y, x);
		}
	}

	/// Lowest common ancestor of x and y. Requires dfs1 and dfs2 to have run.
	int lca(int x, int y) {
		// Repeatedly jump from the chain whose head is deeper.
		while (top[x] != top[y]) {
			if (dep[top[x]] >= dep[top[y]]) x = f[top[x]];
			else y = f[top[y]];
		}
		return (dep[x] < dep[y] ? x : y);
	}
}  // namespace LCA
/// One query path u-v with its precomputed LCA and weighted length.
/// operator< is intentionally reversed (longer path compares "less"), so
/// std::sort puts the longest path at the front of Ask[].
struct ask {
	int u;
	int v;
	int dist;
	int lca;
	bool operator < (const ask &other) const {
		return other.dist < dist;
	}
} Ask[300005];
int n, m;          // node count / query count, read in main()
int sum[300005];   // per-node difference counters rebuilt by check()

/// Larger of two ints.
inline int max(int x, int y) {
	return (x > y ? x : y);
}
/// Smaller of two ints. Compared directly: the former x + y - max(x, y)
/// overflows (undefined behavior) when x + y exceeds INT_MAX.
inline int min(int x, int y) {
	return (x < y ? x : y);
}
/// Post-order pass over the tree stored in LCA::head / LCA::e that replaces
/// each sum[x] with the total of the original sum[] values in x's subtree.
/// @param x  current node
/// @param fa parent of x (-1 for the root)
void dfs(int x, int fa) {
	for (int edge = LCA::head[x]; edge != 0; edge = LCA::e[edge].ne) {
		const int child = LCA::e[edge].to;
		if (child != fa) {
			dfs(child, x);
			sum[x] += sum[child];
		}
	}
}
bool check(int mid) {
	int cnt = 0, maxn = 0;
	memset(sum, 0, sizeof sum);
	for (int i = 1; i <= m; ++i) {
		if (Ask[i].dist <= mid) continue;
		sum[Ask[i].u] ++;
		sum[Ask[i].v] ++;
		sum[Ask[i].lca] -= 2;
		cnt ++;
	}
	dfs(1, -1);
	for (int i = 1; i <= n; i++) {
		if (sum[i] == cnt)
			maxn = max(maxn, LCA::val[i]);
	}
	return Ask[1].dist - maxn <= mid;
}
/// Smallest mid in [_l, _r] for which check(mid) holds. Assumes check is
/// monotone over the range (false ... false, true ... true).
/// Returns _r when no value passes or when the range is empty (_l > _r).
/// Previously `ans` was uninitialized in those cases, which is undefined behavior.
int Binary(int _l, int _r) {
	int l = _l, r = _r;
	int ans = _r;
	while (l <= r) {
		// Same floor midpoint as (l + r) >> 1, but without overflowing l + r.
		int mid = l + (r - l) / 2;
		if (check(mid)) {
			ans = mid;
			r = mid - 1;
		} else {
			l = mid + 1;
		}
	}
	return ans;
}
/// Reads the tree (n - 1 weighted edges) and m query paths, then prints the
/// smallest achievable maximum path length found by Binary().
int main() {
	scanf("%d%d", &n, &m);
	// Start at 0, not -1: with n == 1 or m == 0 the old -1 seeds produced an
	// empty or negative search range.
	int maxEdge = 0;  // heaviest edge weight
	int r = 0;        // longest query path
	for (int i = 1; i <= n - 1; ++i) {
		int u, v, val;
		scanf("%d%d%d", &u, &v, &val);
		LCA::add(u, v, val);
		LCA::add(v, u, val);
		maxEdge = max(maxEdge, val);
	}
	LCA::dfs1(1, -1);
	LCA::dfs2(1, 1, -1);
	for (int i = 1; i <= m; ++i) {
		int u, v;
		scanf("%d%d", &u, &v);
		Ask[i].u = u, Ask[i].v = v;
		Ask[i].lca = (LCA::lca(u, v));
		Ask[i].dist = (LCA::dist[u] + LCA::dist[v] - 2 * (LCA::dist[Ask[i].lca]));
		r = max(r, Ask[i].dist);
	}
	// Longest path first (ask::operator< is reversed); check() reads Ask[1].
	std::sort(Ask + 1, Ask + m + 1);
	// check() subtracts at most one edge weight from the longest path, so the
	// answer is at least r - maxEdge. Clamp at 0, assuming non-negative
	// weights (path lengths can then never be negative).
	int l = max(0, r - maxEdge);
	printf("%d\n", Binary(l, r));
	return 0;
}
2024/10/19 07:44
加载中...