萌新刚学 OI,求助卡常
查看原帖
萌新刚学 OI,求助卡常
350270
CatFromMars（楼主） 2024/9/25 14:40

如题：第 13 个测试点 TLE（1.1s）。

思路就是二分答案，应该是哪里写假了，请巨佬帮忙看看是哪里复杂度假了 / 哪里可以优化 qwq

#include <bits/stdc++.h>
using namespace std;
// Capacity bound for every vertex- and query-indexed global array (each is sized N + 10).
const int N = 300000;
// Reads one (optionally negative) decimal integer from stdin, skipping any
// non-digit characters before it; a '-' among the skipped characters makes
// the result negative.
//
// `ch` is an int, not a char, so that EOF stays distinguishable from real
// bytes. The skip loop also stops at EOF. Previously, truncated input made it
// spin forever. At EOF with no digits read, 0 is returned.
inline int read()
{
	int x = 0, f = 1;
	int ch = getchar();
	while (ch != EOF && (ch < '0' || ch > '9')) {
		if (ch == '-') f = -1;
		ch = getchar();
	}
	while (ch >= '0' && ch <= '9') {
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x * f;
}
// One adjacency-list entry: the vertex at the other end and the edge weight.
struct node {
	int to;   // neighbouring vertex
	int val;  // weight of the edge to `to`
	node(int T, int V) : to(T), val(V) {}
};
// gra[x]: adjacency list of vertex x.
vector <node> gra[N + 10];
// Records the undirected edge x-y of weight z in both endpoints' lists.
void link(int x, int y, int z) {
	gra[x].emplace_back(y, z);
	gra[y].emplace_back(x, z);
}

int jp[N + 10][25], dep[N + 10], sum[N + 10][25];
void dfs(int u, int fa, int w) {
	dep[u] = dep[fa] + 1;
	jp[u][0] = fa;
	for(int i = 1; i <= 19; i++)
		jp[u][i] = jp[jp[u][i - 1]][i - 1];
	sum[u][0] = w;
	for(int i = 1; i <= 19; i++)
		sum[u][i] = sum[u][i - 1] + sum[jp[u][i - 1]][i - 1];
		
	for(int i = 0; i < gra[u].size(); i++) {
		int v = gra[u][i].to, w = gra[u][i].val;
		if(v == fa) continue;
		dfs(v, u, w);
	}
}

// One query path: its endpoints, their lowest common ancestor and the total
// path weight, as produced by lca(). Field order is u, v, l, w; keep it
// stable, because brace initializers of this aggregate are positional.
struct qry {
	int u;  // first endpoint
	int v;  // second endpoint
	int l;  // lowest common ancestor of u and v
	int w;  // total weight of the u-v path
};
qry Q[N + 10];
// Returns {u, v, lca(u, v), weight of the u-v path}, using the binary-lifting
// tables jp/sum and depths dep (which must already be filled).
qry lca(int u, int v) {
	const int endA = u, endB = v;  // reported back unchanged
	int dist = 0;
	int a = u, b = v;
	if(dep[a] < dep[b]) std::swap(a, b);

	// Lift the deeper vertex a up to b's depth, collecting edge weights.
	for(int k = 19; k >= 0; --k) {
		const int up = jp[a][k];
		if(dep[up] < dep[b]) continue;
		dist += sum[a][k];
		a = up;
	}

	if(a != b) {
		// Climb both sides while their 2^k-ancestors still differ.
		for(int k = 19; k >= 0; --k) {
			if(jp[a][k] == jp[b][k]) continue;
			dist += sum[a][k] + sum[b][k];
			a = jp[a][k];
			b = jp[b][k];
		}
		// a and b are now siblings; their common parent is the LCA.
		dist += sum[a][0] + sum[b][0];
		a = jp[a][0];
	}

	qry res;
	res.u = endA;
	res.v = endB;
	res.l = a;
	res.w = dist;
	return res;
}

int n, m;

// e[x]: per-vertex difference counters (filled by check()); gete() turns them
// into subtree sums. ee[] is not referenced anywhere in the visible code.
int e[N + 10], ee[N + 10];
// For every vertex x in the subtree rooted at u, replaces e[x] with the sum of
// e over x's subtree. fa is u's parent; it is skipped and e[fa] is left unchanged.
//
// This is iterative on purpose. The recursive version nested one frame per
// tree level, which is ~N frames on a path-shaped tree, and it was re-run for
// every check() call. Vertices are listed in BFS order, where each one follows
// its parent. They are then folded into their parents in reverse order, so
// each child is fully summed before it is added to its parent. The buffer is
// static so its capacity is reused across calls.
void gete(int u, int fa) {
	static std::vector<std::pair<int, int> > order;  // (vertex, parent)
	order.clear();
	order.push_back(std::make_pair(u, fa));
	for(size_t k = 0; k < order.size(); k++) {
		// Copy out before push_back, which may reallocate `order`.
		const int x = order[k].first, p = order[k].second;
		for(size_t j = 0; j < gra[x].size(); j++) {
			const int v = gra[x][j].to;
			if(v != p)
				order.push_back(std::make_pair(v, x));
		}
	}
	// Index 0 is u itself, which must not feed into fa.
	for(size_t k = order.size() - 1; k > 0; k--)
		e[order[k].second] += e[order[k].first];
}
bool check(int x) {
	for(int i = 1; i <= n; i++)
		e[i] = 0;
	
	int del = -1, cnt = 0;
	for(int i = 1; i <= m; i++) {
		if(Q[i].w > x) {
			del = max(del, Q[i].w - x);
			e[Q[i].u]++, e[Q[i].v]++;
			e[Q[i].l] -= 2;
			cnt++;
		}
	}
	gete(1, 0);
	
	for(int i = 1; i <= n; i++)
		if(sum[i][0] >= del && e[i] >= cnt) 
			return 1;
	return 0;
}
// Reads n and m, then n - 1 edges "x y z" and m queries "x y". Prints the
// smallest x in the searched range for which check(x) succeeds.
int main() {
	n = read();
	m = read();

	int heaviest = 0;  // largest single edge weight
	for(int k = 1; k < n; k++) {
		const int a = read();
		const int b = read();
		const int c = read();
		if(c > heaviest) heaviest = c;
		link(a, b, c);
	}
	dfs(1, 0, 0);

	int longest = 0;  // longest query path
	for(int k = 1; k <= m; k++) {
		const int a = read();
		const int b = read();
		Q[k] = lca(a, b);
		if(Q[k].w > longest) longest = Q[k].w;
	}

	// check(x) needs an edge weighing at least longest - x. No edge exceeds
	// `heaviest`, so any x below longest - heaviest fails. check(longest)
	// always succeeds, because no path is longer than `longest`.
	int lo = longest - heaviest;
	if(lo < 0) lo = 0;
	int hi = longest;
	int best = 0;
	while(lo <= hi) {
		const int mid = lo + (hi - lo) / 2;
		if(check(mid)) {
			best = mid;
			hi = mid - 1;
		} else {
			lo = mid + 1;
		}
	}
	printf("%d\n", best);
	return 0;
}
2024/9/25 14:40
加载中...