改了一上午了,头都裂开了,求大佬帮忙看看,树链剖分本身应该是没有错,线段树或者是查询,修改的部分不知道哪里错了
#include<iostream>
#include<cstring>
#include<vector>
using namespace std;
const int MAXN=1e5+10;
int N,M;
vector<int> adj[MAXN];
int w[MAXN];//点权
int L[MAXN<<1],R[MAXN<<1];
int dep[MAXN];//dep[x] -> x在树上的深度
int fa[MAXN];//fa[x] -> x在树上的父亲
int son[MAXN];//son[x] -> x的重儿子
int sz[MAXN];//sz[x] -> x的子树的节点个数
int dfn[MAXN],tt;//dfs序
int id[MAXN];//dfs序对应的节点编号 -> id[dfn[x]]=x
int top[MAXN];//top[x] -> x所在重链的顶部节点
int d[MAXN<<2];
int b[MAXN<<2];
void build(int s,int t,int p){
if(s==t){
L[p]=R[p]=w[id[s]];
d[p]=1;
return;
}
int m=(s+t)/2;
build(s,m,p*2);
build(m+1,t,p*2+1);
d[p]=(d[p*2]+d[p*2+1]);
L[p]=L[p*2];
R[p]=R[p*2+1];
if(L[p*2+1]==R[p*2]) d[p]--;
}
void update(int l,int r,int c,int s,int t,int p){
if(s>=l&&t<=r){
d[p]=1,b[p]=L[p]=R[p]=c;
return;
}
int m=(s+t)/2;
if(b[p]&&s!=t){
d[p*2]=d[p*2+1]=1;
b[p*2]=b[p*2+1]=b[p];
L[p*2]=L[p*2+1]=R[p*2]=R[p*2+1]=b[p];
b[p]=0;
}
if(l<=m) update(l,r,c,s,m,p*2);
if(r>m) update(l,r,c,m+1,t,p*2+1);
d[p]=(d[p*2]+d[p*2+1]);
R[p]=R[p*2+1];
L[p]=L[p*2];
if(L[p*2+1]==R[p*2]) d[p]--;
}
int query(int l,int r,int s,int t,int p){
if(s>=l&&t<=r) return d[p];
int m=(s+t)/2;
if(b[p]){
d[p*2]=d[p*2+1]=1;
b[p*2]=b[p*2+1]=b[p];
L[p*2]=L[p*2+1]=R[p*2]=R[p*2+1]=b[p];
b[p]=0;
}
int sum=0;
if(l<=m) sum+=query(l,r,s,m,p*2);
if(r>m) sum+=query(l,r,m+1,t,p*2+1);
if(L[p*2+1]==R[p*2]) sum--;
return sum;
}
//处理出fa[x] dep[x] sz[x] son[x]
void dfs1(int u){
sz[u]=1;
for(int k=0;k<adj[u].size();k++){
int v=adj[u][k];
if(v==fa[u]) continue;//防止走回头路
dep[v]=dep[u]+1;
fa[v]=u;
dfs1(v);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]])//查找重儿子
son[u]=v;
}
}
void dfs2(int u,int x){//u为当前点,x为节点u所在重链的顶部节点
dfn[u]=++tt;
id[tt]=u;
top[u]=x;
if(!son[u]) return;//特判 -> 防止对不存在的重儿子进行dfs
dfs2(son[u],x);//优先对重儿子进行dfs,保证同一条重链的点dfs序连续
for(int k=0;k<adj[u].size();k++){
int v=adj[u][k];
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
}
}
int qPoint(int p,int s,int t,int u){
if(s==t) return L[p];
if(b[p]){
d[p*2]=d[p*2+1]=1;
b[p*2]=b[p*2+1]=b[p];
L[p*2]=L[p*2+1]=R[p*2]=R[p*2+1]=b[p];
b[p]=0;
}
int m=(s+t)/2;
if(u<=m) return qPoint(p*2,s,m,u);
else return qPoint(p*2+1,m+1,t,u);
}
int qPath(int u,int v){//查询u到v最短路径上所有节点权值之和
int ans=0,x,y;
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
ans=(ans+query(dfn[top[u]],dfn[u],1,N,1));
x=qPoint(1,1,N,dfn[fa[top[u]]]);
y=qPoint(1,1,N,dfn[top[u]]);
if(x==y) ans--;
u=fa[top[u]];
}
if(dep[u]>dep[v]) swap(u,v);
ans=(ans+query(dfn[u],dfn[v],1,N,1));
return ans?ans:1;
}
void updPath(int u,int v,int dx){//u到v最短路径上所有节点权值加上dx
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
update(dfn[top[u]],dfn[u],dx,1,N,1);
u=fa[top[u]];
}
if(dep[u]>dep[v]) swap(u,v);
update(dfn[u],dfn[v],dx,1,N,1);
}
int main(){
scanf("%d%d",&N,&M);
for(int i=1;i<=N;i++)
scanf("%d",&w[i]);
int a,b;
for(int i=1;i<N;i++){
scanf("%d%d",&a,&b);
adj[a].push_back(b);
adj[b].push_back(a);
}
dfs1(1);
dfs2(1,1);
build(1,N,1);
int x,y,z;
char op[2];
while(M--){
scanf("%s",op);
if(op[0]=='Q'){
scanf("%d%d",&x,&y);
printf("%d\n",qPath(x,y));
}
else{
scanf("%d%d%d",&x,&y,&z);
updPath(x,y,z);
}
}
}