除了第三个点AC以为全部TLE,第四个点本地实测35s(没打错单位),请大佬们帮看看
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const ll N=2e6;
struct node{
ll l,r,f,t,ans;
}a[N+5];
struct node1{
ll x,y;
}b[N+5];
ll n,m,q;
ll c[N+5],v[N+5],w[N+5];
ll cnta,cntb;
ll lenb,blk[N+5];
ll fst[N+5],lst[N+5],dfn[N+5],dfnn;
ll xl[N+5],xll;
ll sum[N+5];
vector<ll> e[N+5];
ll lp=1,rp,s,now;
ll cntc[N+5],jl[N+5];
ll st[N+5][21],dep[N+5];
inline ll read() {
ll x=0,f=1;
char c=getchar();
while (c<'0' || c>'9') {
if (c=='-') f=-1;
c=getchar();
}
while (c>='0' && c<='9') {
x=x*10+c-'0';
c=getchar();
}
return x*f;
}
bool cmp(node l1,node l2){
if(blk[l1.l]==blk[l2.l]){
if(blk[l1.r]==blk[l2.r]){
return l1.t<l2.t;
}
if(blk[l1.l]&1) return l1.r<l2.r;
return l1.r<l2.r;
}
return blk[l1.l]<blk[l2.l];
}
bool cmpf(node l1,node l2){
return l1.f<l2.f;
}
ll get(ll col,ll x){
return sum[x]*v[col];
}
void dfs(ll x,ll fa){
fst[x]=++dfnn;
dfn[dfnn]=x;
dep[x]=dep[fa]+1;
st[x][0]=fa;
xl[x]=++xll;
for(ll y:e[x]){
if(y==fa) continue;
dfs(y,x);
}
lst[x]=++dfnn;
dfn[dfnn]=x;
return;
}
void add(ll p){
ll x=dfn[p],col=c[x];
jl[x]++;
if(jl[x]==2){
s-=get(col,cntc[col]);
cntc[col]--;
s+=get(col,cntc[col]);
return;
}
s-=get(col,cntc[col]);
cntc[col]++;
s+=get(col,cntc[col]);
return;
}
void del(ll p){
ll x=dfn[p],col=c[x];
jl[x]--;
if(jl[x]==1){
s-=get(col,cntc[col]);
cntc[col]++;
s+=get(col,cntc[col]);
return;
}
s-=get(col,cntc[col]);
cntc[col]--;
s+=get(col,cntc[col]);
return;
}
void addd(ll col){
s-=get(col,cntc[col]);
cntc[col]++;
s+=get(col,cntc[col]);
return;
}
void dell(ll col){
s-=get(col,cntc[col]);
cntc[col]--;
s+=get(col,cntc[col]);
return;
}
void solve(ll l,ll r,ll t,ll &ans){
while(l<lp){
lp--;
add(lp);
}
while(lp<l){
del(lp);
lp++;
}
while(rp<r){
rp++;
add(rp);
}
while(r<rp){
del(rp);
rp--;
}
while(now<t){
now++;
// cout<<l<<"oooooo"<<r<<endl;
ll col=c[b[now].x],x=b[now].x;
if(((l<=fst[x] && fst[x]<=r) || (l<=lst[x] && lst[x]<=r)) && jl[x]!=2){
dell(col);
swap(c[b[now].x],b[now].y);
col=c[b[now].x];
addd(col);
}
else swap(c[b[now].x],b[now].y);
}
while(now>t){
// cout<<l<<"rrrrrrrrr"<<r<<endl;
ll col=c[b[now].x],x=b[now].x;
if(((l<=fst[x] && fst[x]<=r) || (l<=lst[x] && lst[x]<=r)) && jl[x]!=2){
dell(col);
swap(c[b[now].x],b[now].y);
col=c[b[now].x];
addd(col);
}
else swap(c[b[now].x],b[now].y);
now--;
}
// cout<<lp<<"ooooooooooooooo"<<rp<<endl;
ans=s;
return;
}
ll lca(ll x,ll y){
if(dep[x]<dep[y]) swap(x,y);
for(ll i=20;i>=0;i--){
ll xx=st[x][i];
if(dep[xx]>=dep[y]) x=xx;
}
if(x==y) return x;
for(ll i=20;i>=0;i--){
ll xx=st[x][i],yy=st[y][i];
if(xx!=yy) x=xx,y=yy;
}
return st[x][0];
}
int main(){
// freopen("P4074_4.in","r",stdin);
// freopen("hyxx.out","w",stdout);
n=read(),m=read(),q=read();
lenb=pow(n*2.0,2.0/3.0);
ll sx=1;
for(ll i=1;i<=2*n;i++){
blk[i]=sx;
if(i%lenb==0) sx++;
}
for(ll i=1;i<=m;i++) v[i]=read();
for(ll i=1;i<=n;i++){
w[i]=read();
sum[i]=sum[i-1]+w[i];
}
for(ll i=1;i<n;i++){
ll x,y;
x=read(),y=read();
e[x].push_back(y);
e[y].push_back(x);
}
for(ll i=1;i<=n;i++) c[i]=read();
for(ll i=1;i<=q;i++){
ll op;
op=read();
if(op==0){
cntb++;
b[cntb].x=read(),b[cntb].y=read();
}
else{
cnta++;
a[cnta].l=read(),a[cnta].r=read();
a[cnta].t=cntb,a[cnta].f=cnta;
}
}
dfs(1,0);
for(ll i=1;i<=20;i++){
for(ll j=1;j<=n;j++){
st[j][i]=st[st[j][i-1]][i-1];
// cout<<j<<" "<<i<<" "<<st[4][0]<<endl;
}
}
// for(ll i=1;i<=n;i++){
// cout<<fst[i]<<"ooo"<<lst[i]<<endl;
// }
// return 0;
sort(a+1,a+cnta+1,cmp);
for(ll i=1;i<=cnta;i++){
ll x=a[i].l,y=a[i].r,lcaa=lca(x,y),xx,yy,ff=0;
if(lcaa==x){
ff=1;
xx=fst[x],yy=fst[y];
}
else if(lcaa==y){
ff=1;
xx=fst[y],yy=fst[x];
}
else{
if(xl[x]>xl[y]) swap(x,y);
xx=lst[x],yy=fst[y];
}
// cout<<xx<<" "<<yy<<" "<<ff<<endl;
solve(xx,yy,a[i].t,a[i].ans);
if(!ff){
ll col=c[lcaa];
// cout<<lcaa<<endl;
a[i].ans+=w[cntc[col]+1]*v[col];
}
// printf("%lld\n",i);
}
sort(a+1,a+cnta+1,cmpf);
for(ll i=1;i<=cnta;i++) printf("%lld\n",a[i].ans);
return 0;
}