有无大佬帮忙看一下,全 MLE 了
查看原帖
有无大佬帮忙看一下,全 MLE 了
381053
CQ_Bab楼主2025/7/29 14:56
#include <bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/tree_policy.hpp>
#include <ext/rope>
using namespace __gnu_pbds;
using namespace std;
#define pb push_back
#define rep(i,x,y) for(register int i=x;i<=y;i++)
#define rep1(i,x,y) for(register int i=x;i>=y;--i)
#define ll int
//#define int long long
#define fire signed
#define il inline
template<class T> il void print(T x) {
	if(x<0) printf("-"),x=-x;
	if (x > 9) print(x / 10);
	putchar(x % 10 + '0');
}
template<class T> il void in(T &x) {
    x = 0; char ch = getchar();
    int f = 1;
    while (ch < '0' || ch > '9') {if(ch=='-') f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar(); }
    x *= f;
}
int T=1;
const int mod=20160501;
const int N=1e5+10;
int dfn[N],val[N];
int a[N],n,m,dep[N];
vector<int>v[N];
int top[N],son[N],siz[N];
struct node{
	int l,r;
	ll tag;
	ll sa,sd,sdd;
	int v[2];
}tr[N<<6];
int rt[N],idx1;
ll d[N],dd[N];
int idx,fat[N];
void up(int x) {
	tr[x].sa=(tr[tr[x].l].sa+tr[tr[x].r].sa)%mod;
	tr[x].sd=(tr[tr[x].l].sd+tr[tr[x].r].sd)%mod;
	tr[x].sdd=(tr[tr[x].l].sdd+tr[tr[x].r].sdd)%mod;
}
int lca(int x,int y) {
	while(top[x]!=top[y]) {
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		x=fat[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	return x;
}
void modify(int u,int l,int r,int l1,int r1,ll k) {
	if(l>=l1&&r<=r1) {
		tr[u].tag+=k;
		tr[u].sa+=1ll*k*(r-l+1)%mod;
		tr[u].sd+=1ll*k*(d[r]-d[l-1])%mod;
		tr[u].sdd+=1ll*k*(dd[r]-dd[l-1])%mod;
		return ;
	}
	int mid=l+r>>1;
	if(mid>=l1) {
		if(tr[u].v[0]) {
			tr[++idx]=tr[tr[u].l];
			tr[u].l=idx;
			tr[tr[u].l].v[0]=tr[tr[u].l].v[1]=1;
			tr[u].v[0]=false;
		}
		modify(tr[u].l,l,mid,l1,r1,k);
	}
	if(mid<r1) {
		if(tr[u].v[1]) {
			tr[++idx]=tr[tr[u].r];
			tr[u].r=idx;
			tr[tr[u].r].v[0]=tr[tr[u].r].v[1]=1;
			tr[u].v[1]=false;
		}
		modify(tr[u].r,mid+1,r,l1,r1,k);
	}
	up(u);
}
void dfs(int x,int fa) {
	dep[x]=dep[fa]+1;
	siz[x]=1;
	fat[x]=fa;
	for(auto to:v[x]) {
		if(to==fa) continue;
		dfs(to,x);
		siz[x]+=siz[to];
		if(siz[to]>siz[son[x]]) son[x]=to;
	}
}
void dfs1(int x,int head) {
	top[x]=head;
	dfn[x]=++idx1;
	val[idx1]=a[x];
	d[idx1]=dep[x];
	dd[idx1]=dep[x]*dep[x];
	if(!son[x]) return;
	dfs1(son[x],head);
	for(auto to:v[x]) {
		if(!dfn[to]) dfs1(to,to);
	}
}
void gai(int x,int y,ll k,int t) {
	while(top[x]!=top[y]) {
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		modify(rt[t],1,n,dfn[top[x]],dfn[x],k);
		x=fat[top[x]];
	}
	if(dfn[x]>dfn[y])swap(x,y);
	modify(rt[t],1,n,dfn[x],dfn[y],k);
}
ll Ans(int u,int l,int r,int k) {
	if(l==r) return tr[u].sa;
	ll res=tr[u].tag;
	int mid=l+r>>1;
	if(mid>=k) res+=Ans(tr[u].l,l,mid,k);
	else res+=Ans(tr[u].r,mid+1,r,k);
	return res;
}
ll Ansa(int u,int l,int r,int l1,int r1) {
	if(l>=l1&&r<=r1) {
		return tr[u].sa;
	}
	ll res=1ll*tr[u].tag*(min(r,r1)-max(l,l1)+1);
	int mid=l+r>>1;
	if(mid>=l1) res+=Ansa(tr[u].l,l,mid,l1,r1);
	if(mid<r1) res+=Ansa(tr[u].r,mid+1,r,l1,r1);
	return res;
}
ll Ansdd(int u,int l,int r,int l1,int r1) {
	if(l1>=l&&r<=r1) {
		return tr[u].sdd;
	}
	int res=1ll*tr[u].tag*(dd[min(r,r1)]-dd[max(l,l1)-1]);
	int mid=l+r>>1;
	if(mid>=l1) res+=Ansdd(tr[u].l,l,mid,l1,r1);
	if(mid<r1) res+=Ansdd(tr[u].r,mid+1,r,l1,r1);
	return res;
}
ll Ansd(int u,int l,int r,int l1,int r1) {
	int res=1ll*tr[u].tag*(d[min(r,r1)]-d[max(l,l1)-1]);
	if(l1>=l&&r<=r1) {
		return tr[u].sd;
	}
	int mid=l+r>>1;
	if(mid>=l1) res+=Ansd(tr[u].l,l,mid,l1,r1);
	if(mid<r1) res+=Ansd(tr[u].r,mid+1,r,l1,r1);
	return res;
}
int qmi(int x,int y) {
	int res=1;
	while(y) {
		if(y&1) res=1ll*res*x%mod;
		x=1ll*x*x%mod;
		y>>=1;
	}
	return res;
}
ll get(int x,int y,ll k,int t) {
	ll res=0;
	while(top[x]!=top[y]) {
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		res+=(Ansdd(rt[t],1,n,dfn[top[x]],dfn[x])+(2*k+1)%mod*Ansd(rt[t],1,n,dfn[top[x]],dfn[x])%mod+Ansa(rt[t],1,n,dfn[top[x]],dfn[x])*(k*k+k)%mod)%mod;
		res%=mod;
		x=fat[top[x]];
	}
	if(dfn[x]>dfn[y]) swap(x,y);
	res+=(Ansdd(rt[t],1,n,dfn[x],dfn[y])+(2*k+1)%mod*Ansd(rt[t],1,n,dfn[x],dfn[y])%mod+Ansa(rt[t],1,n,dfn[x],dfn[y])*(k*k+k)%mod)%mod;
	return res;
}
ll get1(int x,int y,int t) {
	ll res=0;
	while(top[x]!=top[y]) {
		if(dep[top[x]]<dep[top[y]]) swap(x,y);
		res+=Ansa(rt[t],1,n,dfn[top[x]],dfn[x])*dep[y]%mod*dep[y]%mod+Ansa(rt[t],1,n,dfn[top[x]],dfn[x])*dep[y]%mod-Ansd(rt[t],1,n,dfn[top[x]],dfn[x])*(2*dep[y]+1)%mod+Ansdd(rt[t],1,n,dfn[top[x]],dfn[x])%mod;
		res%=mod;
		x=fat[top[x]];
	}
	if(dfn[x]>dfn[y]) swap(x,y);
	res+=Ansa(rt[t],1,n,dfn[x],dfn[y])*dep[y]%mod*dep[y]%mod+Ansa(rt[t],1,n,dfn[x],dfn[y])*dep[y]%mod-Ansd(rt[t],1,n,dfn[x],dfn[y])*(2*dep[y]+1)%mod+Ansdd(rt[t],1,n,dfn[x],dfn[y])%mod;
	return res;
}
void build(int &u,int l,int r) {
	u=++idx;
	if(l==r) {
		tr[u].sa=val[l];
		tr[u].sd=(d[r]-d[l-1])*val[l];
		tr[u].sdd=(dd[r]-dd[l-1])*val[l];
		return;
	}
	int mid=l+r>>1;
	build(tr[u].l,l,mid);
	build(tr[u].r,mid+1,r);
	up(u);
}
int cnt[N];
void solve() {
	in(n),in(m);
	rep(i,1,n-1) {
		int x,y;
		in(x),in(y);
		v[x].pb(y);
		v[y].pb(x);
	} 
	rep(i,1,n) in(a[i]);
	dfs(1,0);
	dfs1(1,1);
	ll lst=0;
	rep(i,1,n) d[i]=d[i-1]+d[i],dd[i]+=dd[i-1];
	int tim=false;
	build(rt[0],1,n);
//	rep(i,1,n) modify(rt[0],1,n,i,i,val[i]);
	int nxt=-1;
	rep(io,1,m) {
		int opt;
		in(opt);
		if(opt==1) {
			int x,y,a;
			in(x),in(y),in(a);
			x=x^lst;
			y=y^lst;
			a=a^lst;
			++tim;
			rt[tim]=++idx;
			if(nxt!=-1) tr[rt[tim]]=tr[rt[nxt]],nxt=-1;
			else tr[rt[tim]]=tr[rt[tim-1]];
			tr[rt[tim]].v[0]=tr[rt[tim]].v[1]=1;
			gai(x,y,a,tim);
		}else if(opt==2) {
			int x,y;
			in(x),in(y);
			int res=0;
			x=x^lst;
			y=y^lst;
			int lc=lca(x,y);
			int now=tim;
			if(nxt!=-1) now=nxt;
			lst=(get(x,lc,dep[y]-dep[lc]*2,now)+get1(lc,y,now))%mod*((mod+1)/2)%mod;
			lst=((lst-Ans(rt[now],1,n,dfn[lc])*((dep[lc]+dep[y]-2*dep[lc])*(dep[y]-dep[lc]+1)/2)%mod)%mod+mod)%mod;
			printf("%lld\n",lst);
		}else {
			int x;
			in(x);
			x^=lst;
			nxt=x;
		}
	}
}
fire main() {
	while(T--) {
		solve();
	}
	return false;
}

思路应该是对的。

2025/7/29 14:56
加载中...