为什么会 MLE
查看原帖
为什么会 MLE
500542
lvyongfeng楼主2025/7/22 19:36
#include<bits/stdc++.h>
// Index helpers for a binary tree stored in an array with the root at index 1:
// ls(a) is the left child of node a, rs(a) is the right child.
#define ls(a) (a << 1)
#define rs(a) (a << 1 | 1)
// Midpoint of the range [a, b], computed with a right shift.
#define md(a, b) ((a + b) >> 1)
using namespace std;
// NOTE(review): `ll` is not used anywhere in the visible code.
typedef long long ll;
// Capacity of the per-node arrays below (nodes are indexed from 1).
const int N = 3e4 + 10;
// Adjacency lists of the tree, indexed by node id.
// NOTE(review): the visible code only indexes this with node ids, so N entries
// would appear to be enough; the 2N size looks like a leftover.
vector <int> vec[N << 1];
// siz[u]: number of nodes in the subtree of u (filled by dfs1).
// hv[u]:  child of u with the largest subtree, 0 if u has no child (filled by dfs1).
// fa[u]:  parent of u, 0 for the root (filled by dfs1).
// a[u]:   value attached to node u, read in main.
// n:      number of nodes, read in main.
int siz[N], hv[N], fa[N], a[N], n;
// First decomposition pass over the tree rooted at the initial call.
// For every node reached it records the parent (fa), the subtree size (siz)
// and the child owning the largest subtree (hv, left at 0 for leaves).
//   u: node being visited
//   f: node we arrived from (0 for the root); it is skipped in the adjacency scan
void dfs1(int u, int f){
	fa[u] = f;
	siz[u] = 1;
	for(size_t i = 0; i < vec[u].size(); ++i){
		const int child = vec[u][i];
		if(child == f){
			continue;  // do not walk back up the edge we came from
		}
		dfs1(child, u);
		siz[u] += siz[child];
		// Strict comparison: on ties the earlier child stays the heavy one.
		if(siz[hv[u]] < siz[child]){
			hv[u] = child;
		}
	}
}
// State produced by the second decomposition pass:
//   dfn[u]: visiting order of u (1-based); heavy children are visited first, so
//           every heavy chain occupies a contiguous range of dfn values
//   dfr[u]: largest dfn assigned inside the subtree of u
//   top[u]: topmost node of the heavy chain containing u
//   dep[u]: depth of u, with the root at depth 1
//   rd:     not written by dfs2
//   cnt:    running counter used to hand out dfn values
int dfn[N], dfr[N], top[N], dep[N], rd[N], cnt;
// Second decomposition pass. Must run after dfs1, since it relies on fa and hv.
//   u: node being visited
//   f: top of the heavy chain that u belongs to (NOT the parent of u)
void dfs2(int u, int f){
	// Depth must come from the real parent. `f` is the chain top, so using
	// dep[f] would give every light child depth 1 and break the
	// "lift the deeper chain first" comparisons made by the path queries.
	dep[u] = dep[fa[u]] + 1;
	top[u] = f;
	dfn[u] = dfr[u] = ++ cnt;
	if(hv[u]){
		// Continue the current chain through the heavy child first.
		dfs2(hv[u], f);
		dfr[u] = dfr[hv[u]];
		for(auto v : vec[u]){
			if(v != fa[u] && v != hv[u]){
				// Every light child starts a new chain with itself as top.
				dfs2(v, v);
				dfr[u] = dfr[v];
			}
		}
	}
}
namespace Sgt{
	struct Node{
		int maxn, sum, t; 
	}tr[N * 4];
	void tag(int u, int t, int l, int r){
		tr[u].t += tr[u].t + t;
		tr[u].sum += t * (r - l + 1);
	}
	void push_up(int u){
		tr[u].maxn = max(tr[ls(u)].maxn ,tr[rs(u)].maxn);
		tr[u].sum = tr[ls(u)].sum + tr[rs(u)].sum;
	}
	void push_down(int u, int l, int r){
		if(tr[u].t){
			int mid = md(l, r);
			tag(ls(u), l, mid, tr[u].t);
			tag(rs(u), mid + 1, r, tr[u].t);
			tr[u].t = 0;
		}
	}
	void build(int l, int r, int u){
		tr[u].t = 0;
		if(l == r){
			tr[u].sum = a[u];
			return ;
		}
		int mid = md(l, r);
		build(l, mid, ls(u));
		build(mid + 1, r, rs(u));
		push_up(u);
	}
	int query_sum(int ql, int qr, int l, int r, int u){
		if(qr <= l && qr >= r){
			return tr[u].sum;
		}
		int ans = 0, mid = md(l, r);
		if(ql <= mid){
			ans += query_sum(ql, qr, l, mid, ls(u));
		}
		if(qr > mid){
			ans += query_sum(ql, qr, mid + 1, r, rs(u));
		}
		push_up(u);
		return ans;
	}
	int query_max(int ql, int qr, int l, int r, int u){
		if(qr <= l && qr >= r){
			return tr[u].maxn;
		}
		int ans = INT_MIN, mid = md(l, r);
		if(ql <= mid){
			ans = max(ans, query_max(ql, qr, l, mid, ls(u)));
		}
		if(qr > mid){
			ans = max(ans, query_max(ql, qr, mid + 1, r, rs(u)));
		}
		push_up(u);
		return ans;
	}
	void update(int qu, int l, int r, int u, int o){
		if(l == r){
			tr[u].maxn = tr[u].sum = o;
			return;
		}
		int mid = md(l, r);
		if(u >= qu){
			update(qu, l, mid, ls(u), o);
		}else{
			update(qu, mid + 1, r, rs(u), o);
		}
		push_up(u);
	}
	int qsum(int u, int v){
		int ans = 0;
		while(top[u] != top[v]){
			if(dep[top[u]] < dep[top[v]]){
				swap(u, v);
			}
			ans += query_sum(dfn[top[u]], dfn[u], 1, n, 1);
			u = fa[top[u]];
		}
		if(dep[u] < dep[v]){
			swap(u, v);
		}
		ans += query_sum(dfn[v], dfn[u], 1, n, 1);
		return ans;
	}
	int qmax(int u, int v){
		int ans = INT_MIN;
		while(top[u] != top[v]){
			if(dep[top[u]] < dep[top[v]]){
				swap(u, v);
			}
			ans = max(ans, query_max(dfn[top[u]], dfn[u], 1, n, 1));
			u = fa[top[u]];
		}
		if(dep[u] < dep[v]){
			swap(u, v);
		}
		ans = max(ans, query_max(dfn[v], dfn[u], 1, n, 1));
		return ans;
	}
}
int main(){
	scanf("%d", &n);
	for(int i = 1; i < n; i++){
		int x, y;
		scanf("%d%d", &x, &y);
		vec[x].push_back(y);
		vec[y].push_back(x);
	}
	for(int i = 1; i <= n; i++){
		scanf("%d", &a[i]);
	}
	dfs1(1, 0);
	dfs2(1, 1);
	Sgt::build(1, n, 1);
	int Q;
	scanf("%d", &Q);
	while(Q--){
		char s[20];
		int p, q;
		scanf("%s", s);
		scanf("%d%d", &p, &q);
		if(s[1] == 'H'){
			Sgt::update(dfn[p], 1, n, 1, q);
		}else if(s[1] == 'M'){
			int ans = Sgt::qmax(p, q);
			printf("%d\n", ans);
		}else if(s[1] == 'S'){
			int ans = Sgt::qsum(p, q);
			printf("%d\n", ans);
		}
	}
	return 0;
} 
2025/7/22 19:36
加载中...