30pts求助
查看原帖
30pts求助
1393897
x11223344楼主2025/1/12 10:55
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
// Capacity for per-vertex arrays: 1e5 vertices plus slack.
constexpr int N = 100000 + 10;

// Closed interval [l, r] of DFS positions: one heavy-chain slice of a path.
struct Node
{
	int l, r;
};

int n, m;          // vertex count, operation count (read in main)
int cnt;           // running DFS timestamp, advanced by dfs2
int a[N];          // a[v]: initial colour of vertex v
int fa[N];         // parent of v (0 for the root)
int dep[N];        // depth of v (root = 1)
int dfn[N];        // vertex -> DFS position
int rdfn[N];       // DFS position -> vertex
int top[N];        // first vertex of the heavy chain containing v
int sz[N];         // subtree size of v
int maxx[N];       // heavy child of v (0 when v is a leaf)
vector<int> p[N];  // adjacency lists
vector<Node> r[2]; // path slices gathered by qry, one list per endpoint side
// First heavy-light pass: records parent, depth and subtree size for every
// vertex below x, and chooses x's heavy child (the child with the largest
// subtree; on ties the first such child in adjacency order is kept).
void dfs1(int x, int f)
{
	fa[x] = f;
	dep[x] = dep[f] + 1;
	sz[x] = 1;
	for (int child : p[x])
	{
		if (child == f) continue;
		dfs1(child, x);
		sz[x] += sz[child];
		// sz[0] == 0, so the first real child always beats the "no child" 0.
		if (sz[child] > sz[maxx[x]]) maxx[x] = child;
	}
}
// Second heavy-light pass: assigns DFS positions so that each heavy chain
// occupies a contiguous range, visiting the heavy child first. t is the head
// of the chain that x belongs to; every light child starts a new chain.
void dfs2(int x, int t)
{
	top[x] = t;
	dfn[x] = ++cnt;
	rdfn[cnt] = x;
	const int heavy = maxx[x];
	if (heavy) dfs2(heavy, t);
	for (int child : p[x])
	{
		if (child == fa[x] || child == heavy) continue;
		dfs2(child, child);
	}
}
// Segment-tree node over DFS positions.
//   l, r : colour at the left / right end of the covered range
//   sum  : number of maximal same-colour runs in the range (0 = empty node)
//   tag  : pending whole-range colour assignment, 0 meaning "none"
struct seg
{
	int l, r, sum, tag;
};
seg s[N << 2];
// Concatenates run summaries s1 (left) and s2 (right). A summary with
// sum == 0 is the empty element and is absorbed. When the touching end
// colours are equal, the two boundary runs fuse into one. The result
// carries no pending tag.
seg pushup(seg s1, seg s2)
{
	if (!s1.sum) return s2;
	if (!s2.sum) return s1;
	seg merged;
	merged.l = s1.l;
	merged.r = s2.r;
	merged.sum = s1.sum + s2.sum;
	if (s1.r == s2.l) --merged.sum;
	merged.tag = 0;
	return merged;
}
// Builds node u covering DFS positions [l, r]. The leaf at position l holds
// the initial colour of vertex rdfn[l] as a single run.
void build(int l, int r, int u)
{
	if (l != r)
	{
		const int mid = (l + r) / 2;
		build(l, mid, u << 1);
		build(mid + 1, r, u << 1 | 1);
		s[u] = pushup(s[u << 1], s[u << 1 | 1]);
		return;
	}
	const int colour = a[rdfn[l]];
	s[u].l = s[u].r = colour;
	s[u].sum = 1;
}
void maketag(int u, int l, int r, int x)
{
	s[u].l = x; s[u].r = x; s[u].sum = 1;
	s[u].tag = x;
}
// Forwards node u's pending colour assignment (u covers [l, r]) to both
// children, then clears it. A tag of 0 means nothing is pending.
void pushdown(int u, int l, int r)
{
	const int pending = s[u].tag;
	if (!pending) return;
	const int mid = (l + r) / 2;
	maketag(u << 1, l, mid, pending);
	maketag(u << 1 | 1, mid + 1, r, pending);
	s[u].tag = 0;
}
// Returns the run summary of DFS positions [l, r] restricted to node u's
// range [L, R], oriented left-to-right. Returns the empty summary
// (sum == 0) when the two ranges do not intersect.
seg query(int l, int r, int L, int R, int u)
{
	if (r < L || R < l) return seg({0, 0, 0, 0});
	if (l <= L && R <= r) return s[u];
	pushdown(u, L, R);
	const int mid = (L + R) / 2;
	const seg left = query(l, r, L, mid, u << 1);
	const seg right = query(l, r, mid + 1, R, u << 1 | 1);
	return pushup(left, right);
}
void update(int l, int r, int L, int R, int u, int x)
{
	if (l <= L && R <= r) maketag(u, L, R, x);
	else if (l <= R && L <= r)
	{
		int mid = (L + R) / 2, ls = u << 1, rs = u << 1 | 1;
		pushdown(u, L, R);
		update(l, r, L, mid, ls, x); update(l, r, mid + 1, R, rs, x);
		s[u] = pushup(s[ls], s[rs]); 
	}
}
int qry(int u, int v)
{
	r[0] = vector<Node>(); r[1] = vector<Node>();
	int x = 1, y = 0;
	while (top[u] != top[v])
	{
		if (dep[top[u]] < dep[top[v]]) swap(u, v), swap(x, y);
		r[x].push_back(Node({dfn[top[u]], dfn[u]}));
		u = fa[top[u]];
	} 
	if (dep[u] > dep[v]) swap(u, v);
	else swap(x, y);
	r[x].push_back(Node({dfn[u], dfn[v]}));
	reverse(r[0].begin(), r[0].end());
	seg ans = {0, 0, 0, 0};
	for (int i = 1; i >= 0; i--)
	{
		for (int j = 0; j < r[i].size(); j++) ans = pushup(ans, query(r[i][j].l, r[i][j].r, 1, n, 1));
		swap(ans.l, ans.r);
	}
	return ans.sum;
}
void upd(int u, int v, int x)
{
	while (top[u] != top[v])
	{
		if (dep[top[u]] < dep[top[v]]) swap(u, v);
		update(dfn[top[u]], dfn[u], 1, n, 1, x);
		u = fa[top[u]];
	}
	update(min(dfn[u], dfn[v]), max(dfn[u], dfn[v]), 1, n, 1, x);
}
// Reads n vertices, m operations, the initial colours and n-1 tree edges,
// builds the decomposition rooted at vertex 1 and the segment tree, then
// processes the operations:
//   'C' u v x : paint path u..v with colour x
//   any other op (presumably 'Q') u v : print the run count on path u..v
int main()
{
	// Up to ~1e5 answers: unsynchronised, untied streams and '\n' instead of
	// endl avoid a flush per line. The output text is unchanged.
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	cin >> n >> m;
	for (int i = 1; i <= n; i++) cin >> a[i];
	for (int i = 1; i < n; i++)
	{
		int u, v;
		cin >> u >> v;
		p[u].push_back(v);
		p[v].push_back(u);
	}
	dfs1(1, 0);
	dfs2(1, 1);
	build(1, n, 1);
	while (m--)
	{
		char op;
		int u, v, x;
		cin >> op >> u >> v;
		if (op == 'C')
		{
			cin >> x;
			upd(u, v, x);
		}
		else cout << qry(u, v) << '\n';
	}
	return 0;
}

测试点 #1–#6 和 hack 数据都通过了，其余测试点都 WA，求调。

2025/1/12 10:55
加载中...