WA13pts求调
查看原帖
WA13pts求调
1623592
florrer_cy楼主2024/12/21 08:51
#include<bits/stdc++.h>
using namespace std;
// Upper bound on the number of tree nodes (with slack); the segment-tree
// arrays below use 4*N slots.
const int N=100005;

// Segment tree over the DFS order; node tid covers positions [s[tid], t[tid]].
int add[N<<2];	// pending (lazy) addition for the node, kept mod p
int sum[N<<2];	// sum over the node's positions, kept mod p
int s[N<<2];	// left end of the node's range
int t[N<<2];	// right end of the node's range

// Per-tree-node data filled by dfs1/dfs2 (nodes are numbered 1..n).
int dep[N];	// depth; dfs1(r,0) gives the root depth 1 (dep[0] stays 0)
int fa[N];	// parent; 0 for the root
int size[N];	// subtree size. NOTE(review): this name is ambiguous with std::size under -std=c++17 + using namespace std
int son[N];	// heavy child (child with the largest subtree), 0 for a leaf
int dfn[N];	// 1-based position of the node in the DFS order
int a[N];	// node weight, reduced mod p while reading input
int top[N];	// top node of the heavy chain containing the node
int id[N];	// inverse of dfn: id[dfn[u]] == u

int tot;	// last DFS order number handed out by dfs2
int n,m,r,p;	// node count, operation count, root, modulus
vector<int>e[N];	// adjacency lists of the undirected tree

// Tree decomposition passes.
void dfs1(int u,int f);
void dfs2(int u,int chainTop);
// Segment-tree primitives (positions are DFS order numbers).
void build(int tid,int lo,int hi);
void pushdown(int tid);
int getsum1(int tid,int l,int r);
void update_add1(int tid,int l,int r,int k);
// Path operations built on top of the segment tree.
int getsum2(int u,int v);
void update_add2(int u,int v,int k);
// Reads the tree, builds the heavy-light decomposition and the segment tree,
// then processes m operations:
//   1 x y z : add z to every node on the path x..y
//   2 x y   : print the sum (mod p) over the path x..y
//   3 x z   : add z to every node in the subtree of x
//   4 x     : print the sum (mod p) over the subtree of x
int main(){
	ios::sync_with_stdio(0);
	cin.tie(nullptr); // untie cin from cout so each read does not force a flush
	cin>>n>>m>>r>>p;
	for(int i=1;i<=n;i++){
		cin>>a[i];
		a[i]%=p;
	}
	for(int i=1,u,v;i<n;i++){
		cin>>u>>v;
		e[u].push_back(v);
		e[v].push_back(u);
	}
	dfs1(r,0);
	dfs2(r,r);
	build(1,1,n);
	for(int type,x,y,z;m--;){
		cin>>type>>x;
		switch(type){
			case 1:
				cin>>y>>z;
				update_add2(x,y,z%p);
				break;
			case 2:
				cin>>y;
				// '\n' instead of endl: output is flushed once at exit, not per query.
				cout<<getsum2(x,y)<<'\n';
				break;
			case 3:
				cin>>z;
				// The subtree of x occupies the contiguous dfn range of length size[x].
				update_add1(1,dfn[x],dfn[x]+size[x]-1,z%p);
				break;
			case 4:
				cout<<getsum1(1,dfn[x],dfn[x]+size[x]-1)<<'\n';
				break;
		}
	}
	return 0;
}
void dfs1(int u,int f){
	fa[u]=f;
	dep[u]=dep[f]+1;
	size[u]=1;
	for(auto v:e[u]){
		if(v==f) continue;
		dfs1(v,u);
		size[u]+=size[v];
		if(size[v]>size[son[u]]) son[u]=v;
	}
}
// Second heavy-light pass: numbers nodes in DFS order, visiting the heavy child
// first so every heavy chain gets consecutive dfn values. u is the current
// node and chainTop the top of the chain u belongs to. Also records the
// inverse mapping id[dfn[u]] = u.
void dfs2(int u,int chainTop){
	top[u]=chainTop;
	dfn[u]=++tot;
	id[tot]=u;
	const int heavy=son[u];
	if(!heavy) return; // leaf
	dfs2(heavy,chainTop); // heavy child continues the current chain
	for(size_t k=0;k<e[u].size();++k){
		const int v=e[u][k];
		if(v!=fa[u]&&v!=heavy) dfs2(v,v); // each light child starts its own chain
	}
}
// Builds segment-tree node tid over DFS-order positions [lo, hi].
// Every node stores its range in s[]/t[]; a leaf at position lo holds the
// weight of tree node id[lo] (mod p), an inner node the sum of its children (mod p).
void build(int tid,int lo,int hi){
	s[tid]=lo;
	t[tid]=hi;
	if(lo==hi){
		sum[tid]=a[id[lo]]%p;
		return;
	}
	const int mid=(lo+hi)>>1,lc=tid<<1,rc=lc|1;
	build(lc,lo,mid);
	build(rc,mid+1,hi);
	sum[tid]=(1ll*sum[lc]+sum[rc])%p;
}
// Moves node tid's pending addition to its two children. Each child's sum
// grows by (child length * tag) mod p and the tag is added to the child's
// own pending tag. Does nothing if tid has no pending addition.
void pushdown(int tid){
	const int lazy=add[tid];
	if(!lazy) return;
	const int kids[2]={tid<<1,(tid<<1)|1};
	for(int c:kids){
		const long long len=t[c]-s[c]+1;
		sum[c]=(sum[c]+len*lazy%p)%p;
		add[c]=(1ll*add[c]+lazy)%p;
	}
	add[tid]=0;
}
// Returns the sum (mod p) of DFS-order positions [l, r] that lie inside
// segment-tree node tid. In this file it is called with tid = 1 and a
// non-empty range [l, r].
//
// The query bounds are forwarded to both children UNCHANGED. Narrowing them to
// [l, mid] / [mid+1, r] would widen the query whenever r < mid or l > mid+1,
// so positions outside [l, r] would be added to the result.
int getsum1(int tid,int l,int r){
	if(s[tid]>=l&&t[tid]<=r) return sum[tid];
	const int mid=(s[tid]+t[tid])>>1;
	int ret=0;
	pushdown(tid); // children must reflect pending additions before being read
	if(l<=mid) ret=getsum1(tid<<1,l,r);
	if(r>mid) ret=(ret+1ll*getsum1((tid<<1)|1,l,r))%p;
	return ret;
}
// Returns the sum (mod p) of the weights of all nodes on the tree path u..v.
// While the endpoints are on different heavy chains, it takes the endpoint
// whose chain top is deeper, adds that chain segment (top..endpoint, a
// contiguous dfn range) and jumps to the parent of the chain top. Once both
// endpoints share a chain, it adds the dfn range between them.
int getsum2(int u,int v){
	int total=0;
	for(;top[u]!=top[v];u=fa[top[u]]){
		if(dep[top[v]]>dep[top[u]]) swap(u,v);
		total=(total+1ll*getsum1(1,dfn[top[u]],dfn[u]))%p;
	}
	if(dep[v]<dep[u]) swap(u,v);
	return (total+1ll*getsum1(1,dfn[u],dfn[v]))%p;
}
void update_add1(int tid,int l,int r,int k){
	if(s[tid]>=l&&t[tid]<=r){
		sum[tid]=(sum[tid]+1ll*(t[tid]-s[tid]+1)*k%p)%p;
		add[tid]=(add[tid]+1ll*k)%p;
		return;
	}
//	//cout<<s[tid]<<" "<<t[tid]<<endl;
	int mid=(s[tid]+t[tid])>>1;
	pushdown(tid);
	if(l<=mid) update_add1(tid<<1,l,r,k);
	if(r>mid) update_add1((tid<<1)|1,l,r,k);
}
// Adds k to the weight of every node on the tree path u..v. While the
// endpoints are on different heavy chains, it updates the chain segment of
// the endpoint whose chain top is deeper and jumps above that chain. It then
// updates the dfn range between the endpoints on their shared chain.
void update_add2(int u,int v,int k){
	for(;top[u]!=top[v];u=fa[top[u]]){
		if(dep[top[v]]>dep[top[u]]) swap(u,v);
		update_add1(1,dfn[top[u]],dfn[u],k);
	}
	if(dep[v]<dep[u]) swap(u,v);
	update_add1(1,dfn[u],dfn[v],k);
}
2024/12/21 08:51
加载中...