树链剖分10pts求调
查看原帖
树链剖分10pts求调
750862
yxyaiunn楼主2024/10/30 13:57
#include <cstdio>
#include <vector>
using namespace std;
int n,P,a[100005],fa[100005],size[100005],dep[100005],son[100005],top[100005],id[100005],rev[100005];
vector<int> v[100005];
int cnt=0;
struct Node{
	long long sum=0;
	int tag=0;
}t[400005];
void dfs1(int now,int f){
	size[now]=1;
	for(int i=0;i<v[now].size();i++){
		int to=v[now][i];
		if(to==f) continue;
		dep[to]=dep[now]+1;
		fa[to]=now;
		dfs1(to,now);
		size[now]+=size[to];
		if(size[to]>size[son[now]]){
			son[now]=to;
		}
	}
	return;
}
void dfs2(int now,int fath){
	top[now]=fath;
	id[now]=++cnt;
	rev[cnt]=now;
	if(!son[now]) return;
	dfs2(son[now],now);
	for(int i=0;i<v[now].size();i++){
		int to=v[now][i];
		if(to!=fa[now]&&to!=son[now]){
			dfs2(to,now);
		}
	}
	return;
}
void build(int l,int r,int o){
	if(l==r){
		t[o].sum=a[rev[l]]%P;
		return;
	}
	int mid=l+r>>1;
	build(l,mid,o*2);
	build(mid+1,r,o*2+1);
	t[o].sum=(t[o*2].sum+t[o*2+1].sum)%P;
	return;
} 
void update1(int l,int r,int o,int now,int vv){
	if(l==r&&l==now){
		t[o].sum=(t[o].sum+vv)%P;
		return;
	}
	int mid=l+r>>1;
	if(mid>=now){
		update1(l,mid,o*2,now,vv);
	}else{
		update1(mid+1,r,o*2+1,now,vv);
	}
	t[o].sum=(t[o*2].sum+t[o*2+1].sum)%P;
	return;
}
long long query(int l,int r,int o,int ll,int rr){
	if(l<=ll&&rr<=r){
		t[o].sum=(t[o].sum+t[o].tag)%P;
		return t[o].sum;
	}
	int mid=l+r>>1;
	long long SUM=0;
	t[o*2].tag=(t[o*2].tag+t[o].tag)%P;
	t[o*2+1].tag=(t[o*2+1].tag+t[o].tag)%P;
	t[o].tag=0;
	if(ll<=mid){
		SUM=(query(l,mid,o*2,ll,rr)+SUM)%P;
	}
	if(mid<r){
		SUM=(query(mid+1,r,o*2+1,ll,rr)+SUM)%P;
	}
	t[o].sum=(t[o*2].sum+t[o*2+1].sum)%P;
	return SUM;
}
void update2(int l,int r,int o,int ll,int rr,int val){
	if(l<=ll&&rr<=r){
		t[o].tag=(t[o].tag+val)%P;
		return;
	}
	int mid=l+r>>1;
	if(ll<=mid){
		update2(l,mid,o*2,ll,rr,val);
	}
	if(mid<r){
		update2(mid+1,r,o*2+1,ll,rr,val);
	}
	return;
}
long long ask(int uu,int vv){
	long long SUM=0;
	while(top[uu]!=top[vv]){
		if(dep[top[uu]]<dep[top[vv]]){
			swap(uu,vv);
		}
		SUM=(query(1,n,1,id[top[uu]],id[uu])+SUM)%P;
		uu=fa[top[uu]];
	}
	if(dep[uu]>dep[vv]){
		swap(uu,vv);
	}
	SUM=(query(1,n,1,id[uu],id[vv])+SUM)%P;
	return SUM;
}
void update3(int uu,int vv,int val){
	while(top[uu]!=top[vv]){
		if(dep[top[uu]]<dep[top[vv]]){
			swap(uu,vv);
		}
		update2(1,n,1,id[top[uu]],id[uu],val);
		uu=fa[top[uu]];
	}
	if(dep[uu]>dep[vv]){
		swap(uu,vv);
	}
	update2(1,n,1,id[uu],id[vv],val);
}
int main(){
	int m,R;
	scanf("%d%d%d%d",&n,&m,&R,&P);
	for(int i=1;i<=n;i++){
		scanf("%d",&a[i]);
	}
	for(int i=1;i<n;i++){
		int uu,vv;
		scanf("%d%d",&uu,&vv);
		v[uu].push_back(vv);
		v[vv].push_back(uu);
	}
	dfs1(R,0);
	dfs2(R,0);
	build(1,n,1);
	for(int i=1;i<=m;i++){
		int op;
		scanf("%d",&op);
		if(op==1){
			int x,y,z;
			scanf("%d%d%d",&x,&y,&z);
			update3(x,y,z);
		}else if(op==2){
			int x,y;
			scanf("%d%d",&x,&y);
			if(id[x]>id[y]){
				swap(x,y);
			}
			query(1,n,1,id[x],id[y]);
			printf("%lld\n",ask(x,y));
		}else if(op==3){
			int x,z;
			scanf("%d%d",&x,&z);
			update2(1,n,1,id[x],id[x]+size[x]-1,z);
		}else{
			int x;
			scanf("%d",&x);
			query(1,n,1,id[x],id[x]+size[x]-1);
			printf("%lld\n",query(1,n,1,id[x],id[x]+size[x]-1));
		}
	}
	return 0;
} 
2024/10/30 13:57
加载中...