只 AC #11树剖求条
查看原帖
只 AC #11树剖求条
941431
Autream楼主2024/11/8 19:46
// Problem: P4315 月下“毛景树”
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P4315
// Memory Limit: 128 MB
// Time Limit: 1000 ms
// Date: 2024/11/08 17:11:04
// Author: Li_Feiy

#include <bits/stdc++.h>
// Shorthand macros used throughout this solution.
// Array I/O over 1-based indices 1..n (built on rep/printk/read).
#define arrout(a, n) rep(i, 1, n) printk(a[i])
#define arrin(a, n) rep(i, 1, n) a[i] = read()
// Inclusive ascending / descending loops.
#define rep(i, x, n) for(int i = x; i <= n; i++)
#define dep(i, x, n) for(int i = x; i >= n; i--)
// Walk the adjacency list of vertex x (head[] / e[].nex are declared further down).
#define erg(i, x) for(int i = head[x]; i; i = e[i].nex)
// Debug: print an expression's source text and its value.
#define dbg(x) std::cout << #x << ":" << x << " "
#define mem(a, x) memset(a, x, sizeof a)
#define all(x) x.begin(), x.end()
// Range [a + 1, a + 1 + n) for 1-based arrays.
#define arrall(a, n) a + 1, a + 1 + n
#define PII std::pair<int, int>
#define m_p std::make_pair
#define u_b upper_bound
#define l_b lower_bound
#define p_b push_back
// CD / CI: const double / const int (CI becomes const long long once `int` is redefined below).
#define CD const double
#define CI const int
// From here on every `int` is `long long`; this is why the entry point is `signed main()`.
#define int long long
#define il inline
#define ss second
#define ff first
// `itn` expands to `int` (and therefore long long) - presumably a guard against that typo.
#define itn int
// Read one signed integer from stdin, skipping every non-digit character before it;
// a '-' among the skipped characters makes the result negative.
// If end of input is reached before any digit, returns 0 instead of spinning forever.
int read() {
	int ch = getchar();  // int rather than char so EOF is distinguishable from data
	int r = 0, w = 1;
	while(ch != EOF && (ch < '0' || ch > '9')) w = ch == '-' ? -1 : w, ch = getchar();
	while(ch >= '0' && ch <= '9') r = (r << 3) + (r << 1) + (ch ^ 48), ch = getchar();
	return r * w;
}

// Write x in decimal to stdout with no separator.
// The magnitude is taken in unsigned arithmetic so the most negative value
// prints correctly instead of overflowing on negation (x = -x is UB there).
void print(int x) {
	unsigned long long u = x;
	if(x < 0) putchar('-'), u = 0ULL - u;
	char buf[24];
	int len = 0;
	do {
		buf[len++] = char('0' + u % 10);
		u /= 10;
	} while(u);
	while(len) putchar(buf[--len]);
}
// Print every argument in turn, with no separator.
template<typename ...Args>
void print(int t, Args... args) { print(t), print(args...); }

// Print x followed by a newline.
void printl(int x) { print(x), putchar('\n'); }
// Print every argument on its own line.
template<typename ...Args>
void printl(int t, Args... args) { printl(t), printl(args...); }

// Print x followed by a single space.
void printk(int x) { print(x), putchar(' '); }
// Print every argument, each followed by a single space.
template<typename ...Args>
void printk(int t, Args ... args) { printk(t), printk(args...); }

// Array capacity (presumably the vertex-count bound plus slack).
CI N = 2e5 + 5;
// n: vertex count; q: not used in the visible code; cnt: dfn counter (dfs2); tot: edge counter (add).
// a[x]: weight of the edge between x and fa[x] (set in dfs1).
// w[i]: a[] re-laid out by dfn position, i.e. w[dfn[x]] = a[x] (set in dfs2, read by st.build).
// fa / dfn / son / top / deep / size: parent, dfs order index, heavy child, chain head, depth, subtree size.
// head[x]: index of the first edge in x's adjacency list (0 = none).
int n, q, cnt, tot, a[N], w[N], fa[N], dfn[N], son[N], top[N], deep[N], head[N], size[N];
// Adjacency-list edge: target vertex, next edge index in the list, weight.
struct edge {
	int to, nex, data;
} e[N << 1];
// The i-th input edge as read (endpoints x, y and weight); "Change i" looks it up here.
struct node {
	int x, y, data;
} b[N];
struct Segment_Tree {
#define ls k << 1
#define rs k << 1 | 1
	int s[N << 2], lazya[N << 2], lazyc[N << 2];
	void pushup(int k) { s[k] = std::max(s[ls], s[rs]); }
	void pushdown(int k, int l, int r) {
		if(lazya[k]) {
			int mid = l + r >> 1;
			// Left
			s[ls] += lazya[k];
			lazya[ls] += lazya[k];
			// Right
			s[rs] += lazya[k];
			lazya[rs] += lazya[k];
			// Reset
			lazya[k] = 0;
		}
		if(lazyc[k]) {
			int mid = l + r >> 1;
			// Left
			s[ls] = lazyc[k];
			lazyc[ls] = lazyc[k];
			lazya[ls] = 0;
			// Right
			s[rs] = lazyc[k];
			lazyc[rs] = lazyc[k];
			lazya[rs] = 0;
			// Reset
			lazyc[k] = lazya[k] = 0;
		}
	}
	void build(int k = 1, int l = 1, int r = n) {
		if(l == r) {
			s[k] = w[l];
			return ;
		}
		int mid = l + r >> 1;
		build(ls, l, mid);
		build(rs, mid + 1, r);
		pushup(k);
	}
	void update(int x, int y, int v, int k = 1, int l = 1, int r = n) {
		if(y < l || x > r) return ;
		if(x <= l && r <= y) {
			s[k] += v, lazya[k] += v;
			return ;
		}
		int mid = l + r >> 1;
		pushdown(k, l, r);
		if(x <= mid) update(x, y, v, ls, l, mid);
		if(y > mid) update(x, y, v, rs, mid + 1, r);
		pushup(k);
	}
	void updatec(int x, int y, int v, int k = 1, int l = 1, int r = n) {
		if(y < l || x > r) return ;
		if(x <= l && r <= y) {
			s[k] = v, lazyc[k] = v, lazya[k] = 0;
			return ;
		}
		int mid = l + r >> 1;
		pushdown(k, l, r);
		if(x <= mid) updatec(x, y, v, ls, l, mid);
		if(y > mid) updatec(x, y, v, rs, mid + 1, r);
		pushup(k);
	}
	int query(int x, int y, int k = 1, int l = 1, int r = n) {
		if(y < l || x > r) return 0;
		if(x <= l && r <= y) return s[k];
		int mid = l + r >> 1, ans = LLONG_MIN;
		pushdown(k, l, r);
		if(x <= mid) ans = std::max(ans, query(x, y, ls, l, mid));
		if(y > mid) ans = std::max(ans, query(x, y, rs, mid + 1, r));
		return ans;
	}
#undef ls
#undef rs
} st;
// Push a directed edge x -> y carrying weight z onto the front of x's adjacency list.
void add(int x, int y, int z) {
	++tot;
	e[tot].nex = head[x];
	e[tot].to = y;
	e[tot].data = z;
	head[x] = tot;
}
// First heavy-light pass over the subtree of x (whose parent is f).
// Records parent, depth and subtree size, chooses son[x] as the child with the strictly
// largest subtree (first one wins ties), and stores the weight of edge child-x in a[child].
void dfs1(int x, int f) {
	deep[x] = deep[f] + 1;
	fa[x] = f;
	size[x] = 1;
	erg(i, x) {
		int child = e[i].to;
		if(child != f) {
			dfs1(child, x);
			a[child] = e[i].data;
			size[x] += size[child];
			if(size[son[x]] < size[child]) son[x] = child;
		}
	}
}
// Second heavy-light pass: number vertices in dfs order, visiting the heavy child first
// so each heavy chain gets consecutive dfn values. Record the chain head top[x] and
// copy each vertex's parent-edge weight into w[] at its dfn position.
void dfs2(int x, int t) {
	top[x] = t;
	++cnt;
	dfn[x] = cnt;
	w[dfn[x]] = a[x];
	if(son[x]) {
		dfs2(son[x], t);
		erg(i, x) {
			int child = e[i].to;
			if(child != fa[x] && child != son[x]) dfs2(child, child);
		}
	}
}
// Lowest common ancestor via heavy chains: keep lifting the endpoint whose chain head is
// deeper to just above that head; once both share a chain, the shallower one is the answer.
int LCA(int x, int y) {
	for(; top[x] != top[y]; x = fa[top[x]])
		if(deep[top[x]] < deep[top[y]]) std::swap(x, y);
	return deep[y] > deep[x] ? x : y;
}
// Add v to every edge on the tree path x - y.
// Each edge lives at the dfn slot of its deeper endpoint. Once x and y share a heavy chain,
// the shallower one (x after the swap) is the LCA, and its own slot holds its parent edge,
// which is not on the path. So the remaining edges are exactly dfn[x] + 1 .. dfn[y].
// This avoids the old read-LCA / update / restore-by-cover round trip.
void update(int x, int y, int v) {
	while(top[x] != top[y]) {
		if(deep[top[y]] > deep[top[x]]) std::swap(x, y);
		st.update(dfn[top[x]], dfn[x], v);
		x = fa[top[x]];
	}
	if(deep[x] > deep[y]) std::swap(x, y);
	if(x != y) st.update(dfn[x] + 1, dfn[y], v);
}
// Set every edge on the tree path x - y to v.
// Edge weights sit at the dfn slot of the deeper endpoint, so after reaching a common chain
// the path's remaining edges are dfn[x] + 1 .. dfn[y]. The LCA's own slot (its parent
// edge) is excluded, with no need to save and restore it.
void updatec(int x, int y, int v) {
	while(top[x] != top[y]) {
		if(deep[top[y]] > deep[top[x]]) std::swap(x, y);
		st.updatec(dfn[top[x]], dfn[x], v);
		x = fa[top[x]];
	}
	if(deep[x] > deep[y]) std::swap(x, y);
	if(x != y) st.updatec(dfn[x] + 1, dfn[y], v);
}
int query(int x, int y) {
	int ans = LLONG_MIN, lca = LCA(x, y), t = st.query(dfn[lca], dfn[lca]);
	st.updatec(dfn[lca], dfn[lca], LLONG_MIN);
	while(top[x] != top[y]) {
		if(deep[top[y]] > deep[top[x]]) std::swap(x, y);
		ans = std::max(ans, st.query(dfn[top[x]], dfn[x]));
		x = fa[top[x]];
	}
	if(deep[x] > deep[y]) std::swap(x, y);
	ans = std::max(ans, st.query(dfn[x], dfn[y]));
	st.updatec(dfn[lca], dfn[lca], t);
	return ans;
}
// Read the tree, build the heavy-light decomposition and segment tree, then process
// operations until "Stop" or end of input:
//   Change i v  - set the i-th input edge to v
//   Cover x y v - set every edge on path x-y to v
//   Add x y v   - add v to every edge on path x-y
//   Max x y     - print the maximum edge on path x-y
// Only Max answers are written to stdout (the old debug puts(op) calls corrupted the output).
signed main() {
	n = read();
	rep(i, 1, n - 1) {
		int x = read(), y = read(), z = read();
		b[i] = node{ x, y, z };
		add(x, y, z), add(y, x, z);
	}
	dfs1(1, 0);
	dfs2(1, 1);
	st.build();
	char op[10];
	// %9s keeps the token inside op; the == 1 check stops at EOF instead of repeating forever.
	while(scanf("%9s", op) == 1 && op[0] != 'S') {
		if(op[0] == 'C' && op[1] == 'h') {
			int i = read(), v = read();
			// The i-th edge's weight is stored at its deeper endpoint.
			int low = deep[b[i].x] > deep[b[i].y] ? b[i].x : b[i].y;
			st.updatec(dfn[low], dfn[low], v);
		} else if(op[0] == 'C' && op[1] == 'o') {
			int x = read(), y = read(), v = read();
			updatec(x, y, v);
		} else if(op[0] == 'A') {
			int x = read(), y = read(), v = read();
			update(x, y, v);
		} else {
			int x = read(), y = read();
			printl(query(x, y));
		}
	}
	return 0;
}

自己把自己 hack 了

7
2 1 1
3 2 6
4 1 6
5 3 1
6 4 1
7 5 4
Cover 7 1 1
Add 1 7 6
Cover 2 6 1
Cover 5 2 7
Add 6 7 2
Max 1 3
Stop

正确输出为 9，而这份代码输出 7。

2024/11/8 19:46
加载中...