求救,线段树合并,样例过不去WA20
查看原帖
求救,线段树合并,样例过不去WA20
578029
ivyjiao楼主2024/10/5 10:27
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=3e7+1,M=6e5+1;
int n,m,u,v,w,a[N],sum,ch[N][2],b[N],rt[N],dep[M],f[M][31],ans[M];
vector<int>G[M];
int merge(int u,int v,int l,int r){
    if(!u||!v) return u+v;
    if(l==r){
        b[u]+=b[v];
        return u;
    }
    int mid=(l+r)/2;
    ch[u][0]=merge(ch[u][0],ch[v][0],l,mid);
    ch[u][1]=merge(ch[u][1],ch[v][1],mid+1,r);
    return u;
}
int add(int u,int l,int r,int x,int k){
    if(!u) u=++sum;
    if(l==r){
        b[u]+=k;
        return u;
    }
    int mid=(l+r)/2;
    if(x<=mid) ch[u][0]=add(ch[u][0],l,mid,x,k);
    else ch[u][1]=add(ch[u][1],mid+1,r,x,k);
    return u;
}
int query(int u,int l,int r,int x){
    if(!u) return 0;
    if(l==r) return b[u];
    int mid=(l+r)/2;
    if(x<=mid) return query(ch[u][0],l,mid,x);
    else return query(ch[u][1],mid+1,r,x);
}
void dfs(int u){
    for(int i=1;i<=20;i++) f[u][i]=f[f[u][i-1]][i-1];
    for(int i=0;i<G[u].size();i++){
        int v=G[u][i];
        if(v==f[u][0]) continue;
        dep[v]=dep[u]+1;
        f[v][0]=u;
        dfs(v);
    }
}
int lca(int x,int y){
    if(x==y) return x;
    if(dep[x]<dep[y]) swap(x,y);
    for(int i=20;i>=0;i--) if(dep[f[x][i]]>=dep[y]) x=f[x][i];
    if(x==y) return x;
    for(int i=20;i>=0;i--) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
    return f[x][0];
}
void cac(int u){
    for(int i=0;i<G[u].size();i++){
        int v=G[u][i];
        if(v==f[u][0]) continue;
        cac(v);
        rt[u]=merge(rt[u],rt[v],1,n*2);
    }
    if(a[u]&&n+dep[u]+a[u]<=n*2) ans[u]+=query(rt[u],1,n*2,n+dep[u]+a[u]);
    ans[u]+=query(rt[u],1,n*2,n+dep[u]+a[u]);
}
signed main(){
    cin>>n>>m;
    for(int i=1;i<n;i++){
        cin>>u>>v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    for(int i=1;i<=n;i++) cin>>a[i];
    dep[1]=1;
    dfs(1);
    while(m--){
        cin>>u>>v;
        int lc=lca(u,v);
        rt[u]=add(rt[u],1,n*2,n+dep[u],1);
        rt[v]=add(rt[v],1,n*2,n+dep[lc]*2-dep[u],1);
        rt[lc]=add(rt[lc],1,n*2,n+dep[u],-1);
        rt[f[lc][0]]=add(rt[f[lc][0]],1,n*2,n+dep[lc]*2-dep[u],-1);
    }
    cac(1);
    for(int i=1;i<=n;i++) cout<<ans[i]<<" ";
}
2024/10/5 10:27
加载中...