30分求调 其他tle
查看原帖
30分求调 其他tle
1040658
Needna楼主2024/10/7 14:37
#include<bits/stdc++.h>
// Contest shortcut: every later `int` token in this file expands to
// `long long` (which is why the entry point is spelled `signed main`).
#define int long long
using namespace std;
// Capacity shared by all per-vertex arrays and by the segment-tree arrays.
// NOTE(review): sum[]/maxn[] are indexed by segment-tree node (p<<1, p<<1|1),
// so this capacity is only sufficient while roughly 4*n stays below N.
const int N=1e6+10;
// tot     : counter incremented by dfs2; after dfs2 it is the number of visited vertices
// n, q    : vertex count and operation count read in main
// w[v]    : value stored on vertex v
// siz[v]  : size of the subtree rooted at v (set by dfs1)
// hson[v] : child of v with the largest subtree, 0 if v has no child (set by dfs1)
// fa[v]   : parent of v in the tree rooted at 1 (set by dfs1)
// dep[v]  : depth of v, with dep[root] left at 0 (set by dfs1)
long long tot,n,q,w[N],siz[N],hson[N],fa[N],dep[N];
// pred[i] : vertex whose dfs2 position is i (inverse of num)
// num[v]  : position of v in dfs2 order; this is the segment-tree index of v
// sum[p]  : sum of the values covered by segment-tree node p
// maxn[p] : maximum of the values covered by segment-tree node p
// top[v]  : first (shallowest) vertex of the heavy chain containing v
long long pred[N],num[N],sum[N],maxn[N],top[N];
// e[u] : adjacency list of vertex u; main stores each edge in both directions.
vector<int> e[N];
// First decomposition pass: walks the tree from u and fills fa[], dep[],
// siz[] for every vertex below u, and picks hson[u] as the child with the
// strictly largest subtree seen so far (earlier child wins ties).
void dfs1(int u){
	siz[u]=1;
	for(size_t i=0;i<e[u].size();++i){
		int child=e[u][i];
		if(child!=fa[u]){
			fa[child]=u;
			dep[child]=dep[u]+1;
			dfs1(child);
			siz[u]+=siz[child];
			// siz[0] is never written, so the first child always replaces hson[u]==0.
			if(siz[child]>siz[hson[u]]) hson[u]=child;
		}
	}
}
// Second decomposition pass: gives u the next position in visiting order
// (num[u], with pred[] as the inverse map) and records tp as the head of its
// chain. The heavy child is visited first so that every heavy chain occupies
// consecutive positions; each other child starts a new chain headed by itself.
void dfs2(int u,int tp){
	++tot;
	num[u]=tot;
	pred[tot]=u;
	top[u]=tp;
	const int heavy=hson[u];
	if(heavy) dfs2(heavy,tp);
	for(size_t i=0;i<e[u].size();++i){
		int child=e[u][i];
		if(child==heavy||child==fa[u]) continue;
		dfs2(child,child);
	}
}
// Builds segment-tree node p covering positions [l, r].
// A leaf at position l takes the value of vertex pred[l]; an inner node
// stores the sum and the maximum of its two children.
void build(int p,int l,int r){
	if(l==r){
		sum[p]=maxn[p]=w[pred[l]];
		return;
	}
	int mid=(l+r)>>1;
	int lc=p<<1;
	int rc=lc|1;
	build(lc,l,mid);
	build(rc,mid+1,r);
	sum[p]=sum[lc]+sum[rc];
	maxn[p]=max(maxn[lc],maxn[rc]);
}
// Point assignment: sets position l to value k.
// p is the current node and [x, y] the range it covers; after descending to
// the leaf, the sum and maximum of every node on the way back are recomputed.
void update(int p,int x,int y,int l,int k){
	if(x==y){
		sum[p]=k;
		maxn[p]=k;
		return;
	}
	int mid=(x+y)>>1;
	int lc=p<<1;
	int rc=lc|1;
	if(l>mid){
		update(rc,mid+1,y,l,k);
	}else{
		update(lc,x,mid,l,k);
	}
	sum[p]=sum[lc]+sum[rc];
	maxn[p]=max(maxn[lc],maxn[rc]);
}
// Range maximum over positions [l, r].
// p is the current node and [x, y] the range it covers. The caller must pass
// l <= r inside [x, y]; there is no branch that handles an empty query range.
int ask_mx(int p,int x,int y,int l,int r){
	if(x>=l&&r>=y) return maxn[p];
	int mid=(x+y)>>1;
	int lc=p<<1;
	int rc=lc|1;
	if(mid<l) return ask_mx(rc,mid+1,y,l,r);   // query lies entirely in the right half
	if(mid>=r) return ask_mx(lc,x,mid,l,r);    // query lies entirely in the left half
	int left_best=ask_mx(lc,x,mid,l,r);
	int right_best=ask_mx(rc,mid+1,y,l,r);
	return max(left_best,right_best);
}
// Maximum vertex value on the tree path between x and y (both inclusive).
//
// While x and y sit on different heavy chains, the endpoint whose chain head
// is deeper is lifted: its chain segment [num[top[x]], num[x]] is queried and
// x jumps to the parent of that chain head. Once both are on the same chain,
// the remaining stretch is one contiguous range of positions.
//
// The segment tree is indexed by dfs2 positions (num[]), not by vertex ids,
// so the chain head must be translated with num[top[x]]. Passing top[x]
// directly queries the wrong range and can even produce an empty range
// (left > right), which ask_mx cannot terminate on.
int get_mx(int x,int y){
	int ans=-INT_MAX;
	while(top[x]!=top[y]){
		// Always lift the endpoint whose chain head is deeper.
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		ans=max(ans,ask_mx(1,1,tot,num[top[x]],num[x]));
		x=fa[top[x]];
	}
	// Same chain now: order the endpoints by position.
	if(num[x]>num[y]) swap(x,y);
	ans=max(ans,ask_mx(1,1,tot,num[x],num[y]));
	return ans;
}
// Range sum over positions [l, r].
// p is the current node and [x, y] the range it covers. The caller must pass
// l <= r inside [x, y]; there is no branch that handles an empty query range.
int ask_sum(int p,int x,int y,int l,int r){
	if(x>=l&&r>=y) return sum[p];
	int mid=(x+y)>>1;
	int lc=p<<1;
	int rc=lc|1;
	if(mid<l) return ask_sum(rc,mid+1,y,l,r);   // query lies entirely in the right half
	if(mid>=r) return ask_sum(lc,x,mid,l,r);    // query lies entirely in the left half
	int right_part=ask_sum(rc,mid+1,y,l,r);
	int left_part=ask_sum(lc,x,mid,l,r);
	return right_part+left_part;
}
// Sum of vertex values on the tree path between x and y (both inclusive).
//
// While x and y sit on different heavy chains, the endpoint whose chain head
// is deeper is lifted: its chain segment [num[top[x]], num[x]] is added and
// x jumps to the parent of that chain head. Once both are on the same chain,
// the remaining stretch is one contiguous range of positions.
//
// The segment tree is indexed by dfs2 positions (num[]), not by vertex ids,
// so the chain head must be translated with num[top[x]]. Passing top[x]
// directly sums the wrong range and can even produce an empty range
// (left > right), which ask_sum cannot terminate on.
int get_sum(int x,int y){
	int ans=0;
	while(top[x]!=top[y]){
		// Always lift the endpoint whose chain head is deeper.
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		ans=ans+ask_sum(1,1,tot,num[top[x]],num[x]);
		x=fa[top[x]];
	}
	// Same chain now: order the endpoints by position.
	if(num[x]>num[y]) swap(x,y);
	ans=ans+ask_sum(1,1,tot,num[x],num[y]);
	return ans;
}
// Entry point. Input layout, as consumed below:
//   n, then n-1 edges "x y", then n vertex values, then q, then q operations.
// Each operation is a word followed by two integers:
//   word starting with 'C'  -> set the value of vertex x to y
//   second letter 'M'       -> print the maximum on the path x..y
//   anything else           -> print the sum on the path x..y
//
// All input goes through scanf so that C stdio and iostream are not mixed,
// locals are spelled long long so "%lld" is correct on its own, the command
// word is read with a width limit matching the buffer, and a short read stops
// processing instead of acting on uninitialised values.
signed main(){
	if(scanf("%lld",&n)!=1) return 0;
	for(long long i=1;i<n;i++){
		long long x,y;
		if(scanf("%lld%lld",&x,&y)!=2) return 0;
		e[x].push_back(y);
		e[y].push_back(x);
	}
	// The decomposition depends only on the tree shape, so it can run before
	// the values are read; build() below needs both.
	dfs1(1);
	dfs2(1,1);
	for(long long i=1;i<=n;i++){
		if(scanf("%lld",&w[i])!=1) return 0;
	}
	build(1,1,tot);
	if(scanf("%lld",&q)!=1) return 0;
	char ch[10];
	for(long long i=1;i<=q;i++){
		long long x,y;
		// %9s leaves room for the terminating NUL in ch[10].
		if(scanf("%9s%lld%lld",ch,&x,&y)!=3) break;
		if(ch[0]=='C'){
			w[x]=y;
			update(1,1,tot,num[x],y);
		}else if(ch[1]=='M'){
			printf("%lld\n",static_cast<long long>(get_mx(x,y)));
		}else{
			printf("%lld\n",static_cast<long long>(get_sum(x,y)));
		}
	}
	return 0;
}
2024/10/7 14:37
加载中...