// 关于这道题为什么不能只靠树链剖分（纯树剖）来做。
// 注：下面的 pushup 与 add 做的是同一种区间信息合并：pushup 把合并结果写回线段树节点，add 则返回一个新节点。
#include<bits/stdc++.h>
//#define int long long
using namespace std;
// Read one decimal integer from stdin.
// Every character before the first digit is skipped; a '-' among the skipped
// characters makes the result negative. Returns 0 if end of input is reached
// before any digit is seen.
// ch is an int (not char) so that EOF stays distinguishable from data bytes and
// isdigit() never receives a negative non-EOF value (undefined behaviour). The
// EOF check keeps the skip loop from spinning forever on truncated input.
int read() {
    int x = 0, w = 1;
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') w = -1;
        ch = getchar();
    }
    while (isdigit(ch)) {  // isdigit(EOF) is false, so this also stops at end of input
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * w;
}
// Capacity of the vertex-indexed arrays (vertices are assumed to be 1-based).
constexpr int N = 200010;
// Large sentinel value; not used in the code visible here.
constexpr int inf = 0x3f3f3f3f;
int n;             // number of vertices, read in main
int a[N];          // a[v]: value read for vertex v
vector<int> g[N];  // g[v]: neighbours of v (main stores each edge in both directions)
int tsp;           // running dfn counter, advanced by dfs2
int rk[N];         // rk[i]: vertex whose dfn is i (inverse of t[v].dfn)
// Per-vertex data filled by dfs1 (fa, dep, siz, son) and dfs2 (top, dfn).
// son == 0 means "no child": t[0] stays zeroed, so t[0].siz == 0.
struct node {
    int fa;   // parent (0 for the root passed as dfs1(1, 0))
    int dep;  // depth, the root gets 1
    int siz;  // subtree size
    int son;  // child with the largest subtree
    int top;  // first vertex of the heavy chain containing this vertex
    int dfn;  // preorder index assigned by dfs2
} t[N];
void dfs1(int x, int fa) {
t[x].fa=fa; t[x].dep=t[fa].dep+1; t[x].siz=1;
for(int y:g[x]) {
if(y==fa) continue;
dfs1(y, x);
t[x].siz+=t[y].siz;
if(t[y].siz>t[t[x].son].siz) t[x].son=y;
}
}
void dfs2(int x, int top) {
t[x].top=top; t[x].dfn=++tsp; rk[tsp]=x;
if(!t[x].son) return ;
dfs2(t[x].son, top);
for(int y:g[x]) {
if(y==t[x].fa||y==t[x].son) continue;
dfs2(y, y);
}
}
// Segment tree over the dfn order produced by dfs2. Node `now` covers the dfn
// interval [l, r] and summarises the multiset of values a[rk[l..r]]:
//   lc, rc : indices of the children in tr (-1 on a leaf)
//   c      : highest frequency of any value in the interval
//   sum    : sum of the distinct values whose frequency equals c
//   s      : the distinct values whose frequency equals c
//   cnt    : value -> frequency table for the whole interval
// cnt is what makes the merge correct: the frequency of a value over a union
// of two intervals is the SUM of its frequencies in the two halves, which
// cannot be recovered from (c, s) alone. Merging only c and s would never
// raise c above 1.
// NOTE(review): sum is int; with large values and big subtrees it can
// overflow. Widening it also requires changing the "%d" in main.
struct trnode{
    int l, r, lc, rc, c, sum;
    set<int> s;
    map<int, int> cnt;
} tr[N<<2]; int trlen;
#define lc(x) tr[x].lc
#define rc(x) tr[x].rc
// Merge the summaries of two disjoint intervals x and y into a fresh node
// (l = r = 0, lc = rc = -1). The empty node {0, 0, -1, -1, 0, 0} acts as the
// identity. Cost is linear in the number of distinct values involved, so
// merging large intervals (as query does for whole subtrees) is expensive.
trnode add(const trnode &x, const trnode &y) {
    trnode as={0, 0, -1, -1, 0, 0};
    // Copy the larger frequency table, then fold the smaller one into it.
    const bool xLarger = x.cnt.size() >= y.cnt.size();
    const trnode &larger = xLarger ? x : y;
    const trnode &smaller = xLarger ? y : x;
    as.cnt = larger.cnt;
    for (const auto &p : smaller.cnt) as.cnt[p.first] += p.second;
    // Recompute the dominant frequency and the values attaining it.
    for (const auto &p : as.cnt) {
        if (p.second > as.c) {
            as.c = p.second;
            as.s.clear();
            as.sum = 0;
        }
        if (p.second == as.c) {
            as.s.insert(p.first);
            as.sum += p.first;
        }
    }
    return as;
}
// Recompute internal node `now` from its two children, using the same merge
// as add() so the two can never drift apart.
void pushup(int now) {
    trnode merged = add(tr[lc(now)], tr[rc(now)]);
    tr[now].c = merged.c;
    tr[now].sum = merged.sum;
    tr[now].s.swap(merged.s);
    tr[now].cnt.swap(merged.cnt);
}
// Build the subtree for dfn interval [l, r]. Nodes are numbered in creation
// order, so with trlen starting at 0 the root of build(1, n) is tr[1].
// Leaf l holds the single value a[rk[l]] with frequency 1.
void build(int l, int r) {
    int now=++trlen;
    tr[now]={l, r, -1, -1, 0, 0};
    if(l==r) {
        const int v=a[rk[l]];
        tr[now].s.insert(v);
        tr[now].cnt[v]=1;
        tr[now].sum=v;
        tr[now].c=1; return ;
    }
    int mid=(l+r)>>1;
    // Each child is created immediately, so its index is trlen+1 right before
    // the recursive call.
    tr[now].lc=trlen+1, build(l, mid);
    tr[now].rc=trlen+1, build(mid+1, r);
    pushup(now);
}
// Summary of the part of [l, r] that lies inside node now's interval.
// Fully covered nodes are returned by copy (cnt included); partially covered
// ones are merged with add() starting from the empty node.
trnode query(int now, int l, int r) {
    if(l<=tr[now].l&&r>=tr[now].r) return tr[now];
    int mid=(tr[now].l+tr[now].r)>>1;
    trnode ans={0, 0, -1, -1, 0, 0};
    if(l<=mid) ans=add(ans, query(lc(now), l, r));
    if(r>mid) ans=add(ans, query(rc(now), l, r));
    return ans;
}
int main() {
n=read();
for(int i=1;i<=n;i++) a[i]=read();
for(int i=1,x,y;i<n;i++) {
x=read(), y=read();
g[x].push_back(y);
g[y].push_back(x);
}
dfs1(1, 0); dfs2(1, 1);
build(1, n);
for(int i=1;i<=n;i++) {
trnode tp=query(1, t[i].dfn, t[i].dfn+t[i].siz-1);
printf("%d ", tp.sum);
}
return 0;
}