求助,悬棺
查看原帖
求助,悬棺
1262406
yhy2024楼主2024/12/20 13:06

rt,55分wa

#include<bits/stdc++.h>
#define N 100005
using namespace std;
int lazy[N],c[N],n,m,x,y,z,cnt;
int he[N],d[N],seg[N],rev[N],top[N],fa[N],son[N],size[N]; 
char op;
struct{
	int to,nxt;
}e[N*2];
void Add(int x,int y){
	e[++cnt].to=y;
	e[cnt].nxt=he[x];
	he[x]=cnt;
}
struct node{
	int sum,cl,cr;
	node operator + (const node &x) const{
		 node ans={0,0,0};
		 ans.cl=cl;
		 ans.cr=x.cr;
		 ans.sum=sum+x.sum-(cr==x.cl);
		 return ans;
	}
}tr[N*4];
void dfs1(int x,int f){
	size[x]=1;
	for(int i=he[x];i;i=e[i].nxt){
		int v=e[i].to;	
		if(v!=f){
			d[v]=d[x]+1,fa[v]=x;
			dfs1(v,x);
			size[x]+=size[v];
			if(size[v]>size[son[x]]){
				son[x]=v;
			}
		}
	}
}
void dfs2(int x){
	if(!son[x])return;
	seg[son[x]]=++seg[0],rev[seg[0]]=son[x];
	top[son[x]]=top[x],dfs2(son[x]);
	for(int i=he[x];i;i=e[i].nxt){
		int v=e[i].to;	
		if(!top[v]){
		  	seg[v]=++seg[0],rev[seg[0]]=v;
        	top[v]=v,dfs2(v);
		}
	}
}
void pushup(int k){
	tr[k]=tr[k<<1]+tr[k<<1|1];
}
void add(int k,int v){
	tr[k].sum=1;
	tr[k].cr=tr[k].cl=v;
	lazy[k]=v;
}
void pushdown(int k){
	if(lazy[k]){
		add(k<<1,lazy[k]);
		add(k<<1|1,lazy[k]);
		lazy[k]=0;
		pushup(k);
	}
}
void build(int k,int l,int r){
	if(l==r){
		tr[k].sum=1;
		tr[k].cr=tr[k].cl=c[rev[l]];
		return;
	}
	int mid=l+r>>1;
	build(k<<1,l,mid);
	build(k<<1|1,mid+1,r);
	pushup(k); 
}
void modify(int k,int l,int r,int x,int y,int v){
	if(y<l||x>r) return;
	if(x<=l&&r<=y){
		add(k,v);	
		return;
	}
	int mid=l+r>>1;
	pushdown(k);
	modify(k<<1,l,mid,x,y,v);
	modify(k<<1|1,mid+1,r,x,y,v);
	pushup(k);
}
void debug(node x){
	printf("l=%d r=%d sum=%d ",x.cl,x.cr,x.sum);
}
node query(int k,int l,int r,int x,int y){
	if(x<=l&&r<=y) return tr[k];
	int mid=l+r>>1;
	node ans={0,0,0};
	pushdown(k);
	if(x<=mid) ans=query(k<<1,l,mid,x,y);
	if(mid<y){
		if(ans.sum) ans=ans+query(k<<1|1,mid+1,r,x,y);		
		else ans=query(k<<1|1,mid+1,r,x,y);
	}
	pushup(k);
	return ans;
}
int QSUM(int x,int y){
	int fx=top[x],fy=top[y],ans=0,ans1=-1,ans2=-1;
	node xx={0,0,0};
	while(fx!=fy){
		if(d[fx]<d[fy]) swap(x,y),swap(fx,fy),swap(ans1,ans2);
		xx=query(1,1,seg[0],seg[fx],seg[x]);
		ans+=xx.sum;
		if(ans1==xx.cr) ans--;
		ans1=xx.cl;
		x=fa[fx],fx=top[x]; 
	} 
	if(d[x]<d[y]) swap(x,y),swap(ans1,ans2);
//	printf("%d %d %d %d\n",ans1,ans2,d[x],d[y]);
	xx=query(1,1,seg[0],seg[y],seg[x]);
	ans+=xx.sum;
	if(xx.cr==ans1) ans--;
	if(xx.cl==ans2) ans--;
	return ans; 
} 
void Modify(int x,int y,int z){
	int fx=top[x],fy=top[y];
	while(fx!=fy){
		if(d[fx]<d[fy]) swap(x,y),swap(fx,fy);
		modify(1,1,seg[0],seg[fx],seg[x],z);
		x=fa[fx],fx=top[x]; 
	}
	if(d[x]>d[y]) swap(x,y);
	modify(1,1,seg[0],seg[x],seg[y],z);
} 
signed main(){
   // freopen("data.in","r",stdin);
	cin>>n>>m;
	for(int i=1;i<=n;i++){
		cin>>c[i];
	}
	for(int i=1;i<=n-1;i++){
		cin>>x>>y;
		Add(x,y),Add(y,x);
	} 
	top[1]=seg[0]=seg[1]=d[1]=rev[1]=1;
	dfs1(1,0),dfs2(1),build(1,1,n);
	for(int i=1;i<=m;i++){
		cin>>op>>x>>y;
		if(op=='C'){
			cin>>z;
			Modify(x,y,z);
		} 
		else{
			cout<<QSUM(x,y)<<endl;
		}
	}
	return 0;
}
2024/12/20 13:06
加载中...