T飞力,救救孩子吧
查看原帖
T飞力,救救孩子吧
1042152
Resss楼主2025/8/1 22:08

除了第三个点AC以外全部TLE,第四个点本地实测35s(没打错单位),请大佬们帮看看

#include <bits/stdc++.h>
#define ll long long
using namespace std;
// Capacity of every global array below. Indices actually used reach at most
// 2n (Euler positions) or the number of stored operations.
// NOTE(review): st[N+5][21] of long long alone is ~336 MB of address space;
// confirm N against the real input limits.
const ll N=2e6;
// One path query. l,r: its two endpoints (filled with vertex ids when read in
// main); f: its position in input order (used by cmpf to restore output order);
// t: number of modifications read before it; ans: the computed answer.
struct node{
	ll l,r,f,t,ans;
}a[N+5];
// One modification: vertex x gets colour y. solve() applies it with a swap, so
// afterwards y holds the replaced colour and applying it again undoes it.
struct node1{
	ll x,y;
}b[N+5];
ll n,m,q;                 // vertex count, colour count, operation count
ll c[N+5],v[N+5],w[N+5];  // c: current colour of each vertex; v: value per colour; w: weight of a colour's k-th occurrence
ll cnta,cntb;             // number of queries in a[] / modifications in b[]
ll lenb,blk[N+5];         // Mo block length and block id of every Euler position 1..2n
ll fst[N+5],lst[N+5],dfn[N+5],dfnn;  // entry/exit Euler position per vertex; dfn: position -> vertex; dfnn: positions used
ll xl[N+5],xll;           // preorder rank of each vertex and its counter
ll sum[N+5];              // prefix sums of w: sum[k] = w[1]+...+w[k]
vector<ll> e[N+5];        // adjacency lists of the tree
ll lp=1,rp,s,now;         // Mo state: window [lp,rp], its value s, number of modifications currently applied
ll cntc[N+5],jl[N+5];     // cntc: per-colour count of counted vertices; jl: net add/del count of a vertex's Euler positions in the window
ll st[N+5][21],dep[N+5];  // binary lifting: st[x][k] = 2^k-th ancestor (0 above the root); dep: depth (root = 1)
// Fast integer reader on stdin: skips every non-digit character (a '-' seen
// among them makes the result negative), then consumes a run of decimal digits.
inline ll read() {
	ll value=0;
	bool negative=false;
	int ch=getchar();
	for(;ch<'0'||ch>'9';ch=getchar()){
		if(ch=='-') negative=true;
	}
	for(;ch>='0'&&ch<='9';ch=getchar()){
		value=value*10+(ch-'0');
	}
	return negative?-value:value;
}
bool cmp(node l1,node l2){
	if(blk[l1.l]==blk[l2.l]){
		if(blk[l1.r]==blk[l2.r]){
			return l1.t<l2.t;
		}
		if(blk[l1.l]&1) return l1.r<l2.r;
		return l1.r<l2.r;
	}
	return blk[l1.l]<blk[l2.l];
}
// Restores input order: queries compare by their original index f.
bool cmpf(node l1,node l2){
	const ll lhs=l1.f;
	const ll rhs=l2.f;
	return rhs>lhs;
}
// Contribution of colour col when it is counted x times:
// v[col] * (w[1] + ... + w[x]), taken from the prefix sums in sum[].
ll get(ll col,ll x){
	const ll prefix=sum[x];
	return v[col]*prefix;
}
// Euler tour from x (parent fa): each vertex gets an entry position fst[] and
// an exit position lst[], both mapped back through dfn[]. Also fills depth,
// the direct parent (st[x][0]) and the preorder rank xl[].
void dfs(ll x,ll fa){
	dfnn++;
	fst[x]=dfnn;
	dfn[dfnn]=x;
	st[x][0]=fa;
	dep[x]=1+dep[fa];
	xll++;
	xl[x]=xll;
	const vector<ll> &adj=e[x];
	for(size_t k=0;k<adj.size();k++){
		if(adj[k]!=fa) dfs(adj[k],x);
	}
	dfnn++;
	lst[x]=dfnn;
	dfn[dfnn]=x;
}
// Brings Euler position p into the window. A vertex is counted for its colour
// only while exactly one of its two positions is inside, so its second arrival
// (jl becomes 2) removes it from cntc instead of adding it.
void add(ll p){
	const ll vtx=dfn[p];
	const ll col=c[vtx];
	++jl[vtx];
	const ll step=(jl[vtx]==2)?-1:1;
	s-=get(col,cntc[col]);
	cntc[col]+=step;
	s+=get(col,cntc[col]);
}
// Takes Euler position p out of the window. If one of the vertex's positions
// is still inside afterwards (jl becomes 1) the vertex becomes counted again;
// otherwise it stops being counted for its colour.
void del(ll p){
	const ll vtx=dfn[p];
	const ll col=c[vtx];
	--jl[vtx];
	const ll step=(jl[vtx]==1)?1:-1;
	s-=get(col,cntc[col]);
	cntc[col]+=step;
	s+=get(col,cntc[col]);
}
// Counts one more vertex of colour col and updates the running value s.
void addd(ll col){
	ll &cnt=cntc[col];
	s+=get(col,cnt+1)-get(col,cnt);
	cnt++;
}
// Counts one fewer vertex of colour col and updates the running value s.
void dell(ll col){
	ll &cnt=cntc[col];
	s+=get(col,cnt-1)-get(col,cnt);
	cnt--;
}
// Moves the global Mo state (window [lp,rp], number of applied modifications
// `now`) to the Euler window [l,r] at time t and stores the value s in ans.
// Assumes l <= r, which main guarantees for every query window.
void solve(ll l,ll r,ll t,ll &ans){
	// Grow before shrinking. Shrinking first can push lp past rp. del() then
	// runs on positions that were never added, cntc[] goes negative, and get()
	// reads sum[-1], sum[-2], ... out of bounds. Growing first keeps
	// lp <= rp+1 at every step, so jl[] is always the true count in {0,1,2}.
	while(l<lp) add(--lp);
	while(rp<r) add(++rp);
	while(lp<l) del(lp++);
	while(r<rp) del(rp--);
	// Applies modification k, or undoes it when applied a second time, since
	// swap leaves the replaced colour in b[k].y. A vertex is counted in the
	// window iff exactly one of its Euler positions is inside (jl == 1); only
	// then does its colour count move from the old colour to the new one.
	auto toggle=[&](ll k){
		const ll x=b[k].x;
		const bool counted=(jl[x]==1);
		if(counted) dell(c[x]);
		swap(c[x],b[k].y);
		if(counted) addd(c[x]);
	};
	while(now<t) toggle(++now);
	while(now>t) toggle(now--);
	ans=s;
}
// Lowest common ancestor by binary lifting over st[][] (2^k-th ancestors,
// 0 above the root, depth of 0 is 0).
ll lca(ll x,ll y){
	if(dep[x]<dep[y]) swap(x,y);
	// Lift the deeper vertex x up to y's depth.
	for(int k=20;k>=0;--k){
		const ll up=st[x][k];
		if(dep[up]>=dep[y]) x=up;
	}
	if(x==y) return x;
	// Lift both while their ancestors differ; they end just below the LCA.
	for(int k=20;k>=0;--k){
		const ll ux=st[x][k];
		const ll uy=st[y][k];
		if(ux==uy) continue;
		x=ux;
		y=uy;
	}
	return st[x][0];
}
int main(){
//	freopen("P4074_4.in","r",stdin);
//	freopen("hyxx.out","w",stdout);
	n=read(),m=read(),q=read();
	lenb=pow(n*2.0,2.0/3.0);
	ll sx=1;
	for(ll i=1;i<=2*n;i++){
		blk[i]=sx;
		if(i%lenb==0) sx++;
	}
	for(ll i=1;i<=m;i++) v[i]=read();
	for(ll i=1;i<=n;i++){
		w[i]=read();
		sum[i]=sum[i-1]+w[i];
	}
	for(ll i=1;i<n;i++){
		ll x,y;
		x=read(),y=read();
		e[x].push_back(y);
		e[y].push_back(x);
	}
	for(ll i=1;i<=n;i++) c[i]=read();
	for(ll i=1;i<=q;i++){
		ll op;
		op=read();
		if(op==0){
			cntb++;
			b[cntb].x=read(),b[cntb].y=read();
		}
		else{
			cnta++;
			a[cnta].l=read(),a[cnta].r=read();
			a[cnta].t=cntb,a[cnta].f=cnta;
		}
	}
	dfs(1,0);
	for(ll i=1;i<=20;i++){
		for(ll j=1;j<=n;j++){
			st[j][i]=st[st[j][i-1]][i-1];
//			cout<<j<<" "<<i<<" "<<st[4][0]<<endl;
		}
	}
//	for(ll i=1;i<=n;i++){
//		cout<<fst[i]<<"ooo"<<lst[i]<<endl;
//	}
//	return 0;
	sort(a+1,a+cnta+1,cmp);
	for(ll i=1;i<=cnta;i++){
		ll x=a[i].l,y=a[i].r,lcaa=lca(x,y),xx,yy,ff=0;
		if(lcaa==x){
			ff=1;
			xx=fst[x],yy=fst[y];
		}
		else if(lcaa==y){
			ff=1;
			xx=fst[y],yy=fst[x];
		}
		else{
			if(xl[x]>xl[y]) swap(x,y);
			xx=lst[x],yy=fst[y];
		}
//		cout<<xx<<" "<<yy<<" "<<ff<<endl;
		solve(xx,yy,a[i].t,a[i].ans);
		if(!ff){
			ll col=c[lcaa];
//			cout<<lcaa<<endl;
			a[i].ans+=w[cntc[col]+1]*v[col];
		}
//		printf("%lld\n",i);
	}
	sort(a+1,a+cnta+1,cmpf);
	for(ll i=1;i<=cnta;i++) printf("%lld\n",a[i].ans);
	return 0;
}
2025/8/1 22:08
加载中...