RT,全RE了,原因是 ins 时一直在跳,却又查不出哪里错了,求调感谢!
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
inline int read() {
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
while (ch>='0'&& ch<='9'){x=x*10+ch-48;ch=getchar();}
return x*f;
}
const int N = 5e5 + 10, M = 19, K = 5e6 + 10, inf = 1e9;
int n,m,a[N],dfn[N],dep[N],cnt,mn[M][N],fa[N];
vector<int> e[N];
inline void add(int u,int v){
e[u].push_back(v);
}
inline int calc(int x,int y){
return dfn[x] < dfn[y] ? x : y;
}
void dfs(int u,int fa){
dep[u]=dep[fa]+1;
mn[0][dfn[u]=++cnt]=fa;
for (auto v:e[u]){
if (v==fa) continue;
dfs(v,u);
}
}
int siz[N],vis[N],mx[N],rt;
int calcrt(int u,int fa,int tot){
siz[u]=1,mx[u]=0;
for (auto v:e[u]){
if (v==fa || vis[v]) continue;
calcrt(v,u,tot);
siz[u]+=siz[v];
mx[u]=max(mx[u],siz[v]);
}
mx[u]=max(mx[u],tot-siz[u]);
if (mx[u]<mx[rt]) rt=u;
}
void build(int u,int tot){
vis[u]=1;
for (auto v:e[u]){
if (vis[v]) continue;
rt=0;
int p=(siz[v]<siz[u]) ? siz[u] : (tot-siz[u]);
calcrt(v,u,p);
fa[rt]=u;
build(rt,p);
}
}
void init(){
for (int i=1;i<=18;i++){
for (int j=1;j<=n;j++){
mn[i][j]=calc(mn[i-1][j],mn[i-1][j+(1<<i-1)]);
}
}
}
inline int lca(int u,int v){
if(u==v) return u;
if((u=dfn[u])>(v=dfn[v])) swap(u,v);
int d=__lg(v-u++);
return calc(mn[d][u],mn[d][v-(1<<d)+1]);
}
inline int dis(int x,int y){
return dep[x]+dep[y]-(dep[lca(x,y)]<<1);
}
struct segment{
int rt[N],lc[K],rc[K],w[K];
inline void pushup(int p){
w[p]=w[lc[p]]+w[rc[p]];
}
void ins(int &p,int l,int r,int v,int k){
if(!p) p=++cnt;
if(l==r){w[p]+=k;return ;}
int mid=l+r>>1;
if(v<=mid) ins(lc[p],l,mid,v,k);
else ins(rc[p],mid+1,r,v,k);
pushup(p);
}
int qry(int p,int cl,int cr,int l,int r){
if(!p || cr<l || cl>r) return 0;
if(l<=cl && cr<=r) return w[p];
int mid=cl+cr>>1;
return qry(lc[p],cl,mid,l,r)+qry(rc[p],mid+1,cr,l,r);
}
}t1,t2;
void ins(int u,int w){
int v=u;
while(u){
t1.ins(t1.rt[u],0,n,dis(u,v),w);
if(fa[u]) t2.ins(t2.rt[u],0,n,dis(u,fa[u]),w);
u=fa[u];
}
}
int qry(int u,int k){
int v=u,son=0,res=0;
while(u){
if(dis(u,v)>k){
son=u;
u=fa[u];
continue;
}
res+=t1.qry(t1.rt[u],0,n,0,k-dis(u,v));
if(son) res-=t2.qry(t2.rt[son],0,n,0,k-dis(u,v));
son=u;
u=fa[u];
}
return res;
}
int main(){
n=read(),m=read();
for(int i=1;i<=n;i++){
a[i]=read();
}
for(int i=1;i<=n-1;i++){
int u=read(),v=read();
add(u,v),add(v,u);
}
dfs(1,0);
init();
rt=0;
mx[rt]=inf;
calcrt(1,0,n);
build(rt,n);
for(int i=1;i<=n;i++){
ins(i,a[i]);
}
int x=0;
while(m--){
int opr=read(),u=read()^x,k=read()^x;
if(opr==0){
printf("%d\n",x=qry(u,k));
}
else{
ins(x,k-a[u]);
a[u]=k;
}
}
return 0;
}