代码如下
#include<bits/stdc++.h>
using namespace std;
inline int ls(int x){return x<<1;}
inline int rs(int x){return x<<1|1;}
const int maxn=2e5+10;
struct edge
{
int next,v;
}e[2*maxn];
struct point
{
int l,r,maxx,sum;
}tree[4*maxn];
int rt,n,q,r,a[maxn],cnt,head[maxn*2],f[maxn],d[maxn],size[maxn],son[maxn],rk[maxn],top[maxn],id[maxn],num;
void add(int u,int v)
{
num++;
e[num].v=v;
e[num].next=head[u];
head[u]=num;
}
void dfs1(int u,int fa,int depth)
{
f[u]=fa;
d[u]=depth;
size[u]=1;
for(int i=head[u];i;i=e[i].next)
{
int v=e[i].v;
if(v==fa)continue;
dfs1(v,u,depth+1);
size[u]+=size[v];
if(size[u]>size[son[u]] || !son[u])
{
son[u]=v;
}
}
}
void dfs2(int u,int t)
{
top[u]=t;
id[u]=++cnt;
rk[cnt]=u;
if(!son[u])return;
dfs2(son[u],t);
for(int i=head[u];i;i=e[i].next)
{
int v=e[i].v;
if(v!=son[u]&&v!=f[u])
{
dfs2(v,v);
}
}
}
void build(int p,int l,int r)
{
int mid=(l+r)>>1;
tree[p].l=l;
tree[p].r=r;
if(l==r)
{
tree[p].maxx=tree[p].sum=a[rk[mid]];
return;
}
build(ls(p),l,mid);
build(rs(p),mid+1,r);
tree[p].maxx=max(tree[ls(p)].maxx,tree[rs(p)].maxx);
tree[p].sum=tree[ls(p)].sum+tree[rs(p)].sum;
}
int query_max(int p,int l,int r)
{
if(tree[p].l>=l&&tree[p].r<=r)
{
return tree[p].maxx;
}
int mid=(tree[p].l+tree[p].r)>>1;
int s=-0x7fffff;
if(l<=mid)
{
s=max(query_max(ls(p),l,r),s);
}
if(r>mid)
{
s=max(query_max(rs(p),l,r),s);
}
return s;
}
int query_sum(int p,int l,int r)
{
if(tree[p].l>=l&&tree[p].r<=r)
{
return tree[p].sum;
}
int mid=(tree[p].l+tree[p].r)>>1;
int s=0;
if(l<=mid)
{
s+=query_sum(ls(p),l,r);
}
if(r>mid)
{
s+=query_sum(rs(p),l,r);
}
return s;
}
void change(int p,int l,int x)
{
if(tree[p].l==tree[p].r)
{
tree[p].sum=x;
tree[p].maxx=x;
return;
}
int mid=(tree[p].l+tree[p].r)>>1;
if(l<=mid)
{
change(ls(p),l,x);
}
if(l>mid)
{
change(rs(p),l,x);
}
tree[p].sum=tree[ls(p)].sum+tree[rs(p)].sum;
tree[p].maxx=max(tree[ls(p)].maxx,tree[rs(p)].maxx);
}
int find_sum(int x,int y)
{
int ret=0,fx=top[x],fy=top[y];
while(fx!=fy)
{
if(d[fx]>=d[fy])
{
ret+=query_sum(1,id[fx],id[x]);
x=f[fx];
fx=top[x];
}else
{
ret+=query_sum(1,id[fy],id[y]);
y=f[fy];
fy=top[y];
}
}
if(id[x]<=id[y])ret+=query_sum(1,id[x],id[y]);
else ret+=query_sum(1,id[y],id[x]);
return ret;
}
int find_max(int x,int y)
{
int ret=-21474836,fx=top[x],fy=top[y];
while(fx!=fy)
{
if(d[fx]>=d[fy])
{
ret=max(query_max(1,id[fx],id[x]),ret);
x=f[fx];
fx=top[x];
}else
{
ret=max(query_max(1,id[fy],id[y]),ret);
y=f[fy];
fy=top[y];
}
}
if(id[x]<=id[y])ret=max(query_max(1,id[x],id[y]),ret);
else ret=max(query_max(1,id[y],id[x]),ret);
return ret;
}
int main()
{
cin>>n;
for(int i=1;i<=n-1;i++)
{
int u,v;
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
for(int i=1;i<=n;i++)
{
cin>>a[i];
}
dfs1(1,0,1);
dfs2(1,1);
build(1,1,n);
cin>>q;
for(int i=1;i<=q;i++)
{
string s;
cin>>s;
if(s=="QSUM")
{
int x,y;
scanf("%d%d",&x,&y);
printf("%d\n",find_sum(x,y));
}else if(s=="QMAX")
{
int x,y;
scanf("%d%d",&x,&y);
printf("%d\n",find_max(x,y));
}else
{
int x,t;
scanf("%d%d",&x,&t);
change(1,id[x],t);
}
}
return 0;
}