真的不知道错哪了QAQ
code
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+10;   // capacity for node ids (Euler-tour arrays below use 2*N)
// Sparse-table levels: every tour length L <= 2*N satisfies L < 2^21, so
// floor(log2(L)) <= 20 and rows 0..21 are more than enough (previously 30 rows).
const int LOG=22;
typedef set<int>::iterator it;
typedef long long ll;
int dfn[N*2],ndcnt;   // dfn[u]: first Euler-tour position of node u; ndcnt: tour length so far
// One directed adjacency-list entry: target, index of the next entry of the same
// source (-1 terminates), and the edge length.
struct line{
int to,nxt;
ll weight;
};
line edge[N*2];       // each undirected edge is stored twice (u->v and v->u)
int edge_cnt,head[N];
// ST[k][j]: shallowest node among tour positions j..j+2^k-1 (ST[0] is the tour itself).
// logn is indexed by tour length (up to 2n-1), so it must be as long as a tour row;
// it used to be logn[N] and overflowed for large n.
int ST[LOG][N*2],logn[N*2],depth[N*2];
int n,m;
ll weight[N];         // weight[u]: length of the path from the DFS start node to u
bool inquery[N];      // inquery[u]: node u is currently marked
set<int> st;          // dfn values of the currently marked nodes
// Pushes the directed edge u -> v of length w onto the front of u's adjacency list.
// head[u] always holds the newest entry; each entry's nxt chains to the one added
// before it, ending at head's initial value (-1, set by main's memset).
inline void add(int u,int v,ll w){
    auto &e=edge[edge_cnt];
    e.nxt=head[u];
    e.to=v;
    e.weight=w;
    head[u]=edge_cnt;
    ++edge_cnt;
}
// Euler-tour DFS feeding the sparse-table LCA.
// On entry: records root's first tour position in dfn[root], writes root into the tour
// (ST[0]), and sets depth[root] = depth[father] + 1.
// For each child v: weight[v] = weight[root] + length of edge (root, v); after v's
// subtree returns, root is appended to the tour again.
// NOTE(review): recursion depth equals the tree height, so a long path-shaped tree
// may exhaust the default stack.
void dfs(int root,int father){
    dfn[root]=++ndcnt;
    ST[0][dfn[root]]=root;
    depth[root]=depth[father]+1;
    for(int i=head[root];~i;i=edge[i].nxt){
        int v=edge[i].to;
        if(v==father) continue;
        // Add the 64-bit edge length directly; the old `int w` narrowed it first.
        weight[v]=weight[root]+edge[i].weight;
        dfs(v,root);
        ST[0][++ndcnt]=root;   // back at root after finishing v's subtree
    }
}
// Builds the range-minimum sparse table over the Euler tour ST[0][1..ndcnt].
// logn[x] = floor(log2(x)) for 2 <= x <= ndcnt (logn[1] stays 0).
// ST[k][j] = the node with the smallest depth[] among tour positions j .. j+2^k-1;
// on equal depth the right-hand candidate is kept.
void init(){
    for(int len=2;len<=ndcnt;++len) logn[len]=1+logn[len/2];
    for(int k=1;k<=logn[ndcnt];++k){
        const int half=1<<(k-1);
        for(int j=1;j+2*half-1<=ndcnt;++j){
            const int left=ST[k-1][j];
            const int right=ST[k-1][j+half];
            ST[k][j]=depth[left]<depth[right]?left:right;
        }
    }
}
inline int lca(int u,int v){
if(u==v) return u;
if(dfn[u]>dfn[v]) swap(u,v);
int k=logn[dfn[v]-dfn[u]+1];
return ( depth[ ST[k][dfn[u]] ] < depth[ ST[k][dfn[v]-(1<<k)+1] ] ? ST[k][dfn[u]] : ST[k][dfn[v]-(1<<k)+1] );
}
// Tree distance between nodes x and y: weight[x] + weight[y] - 2 * weight[lca(x, y)].
// NOTE: this is a macro, so x and y are expanded (and evaluated) more than once;
// pass only side-effect-free expressions.
#define cal(x,y) (weight[(x)]+weight[(y)]-2*weight[lca((x),(y))])
// Reads the next decimal integer from stdin.
// Characters before the first digit are skipped; a '-' anywhere among them makes the
// result negative. Returns 0 if input ends before any digit (previously an endless loop).
inline int read(){
    int sign=1,x=0;
    int c=getchar();   // int, not char: EOF must stay distinguishable from a real byte
    while(c<'0'||c>'9'){
        if(c==EOF) return 0;
        if(c=='-') sign=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9'){
        x=x*10+(c-'0');
        c=getchar();
    }
    return sign*x;
}
// Writes x in decimal to stdout with no separator; negative values get a leading '-'.
// Takes long long so 64-bit answers can be printed; int arguments still convert
// implicitly. Previously negative input emitted garbage because x%10 was negative.
inline void print(long long x){
    // Work on the unsigned magnitude so that negating LLONG_MIN cannot overflow.
    unsigned long long mag=static_cast<unsigned long long>(x);
    if(x<0){
        putchar('-');
        mag=0ULL-mag;
    }
    char digits[20];
    int len=0;
    do{
        digits[len++]=static_cast<char>('0'+mag%10);
        mag/=10;
    }while(mag);
    while(len) putchar(digits[--len]);
}
ll ans;
int main(){
// freopen("1.in","r",stdin);
memset(head,-1,sizeof head);
n=read();m=read();
for(ll x,y,w,i=1;i<n;i++){
x=read(),y=read(),w=read();
add(x,y,w);
add(y,x,w);
}
dfs(1,0);
init();
for(int nd;m--;){
nd=read();
if(st.size()==0&&(!inquery[nd])) <%st.insert(dfn[nd]),inquery[nd]=1;cout<<ans<<"\n";continue;%>
if(st.size()==1&&inquery[nd]) <%st.erase(dfn[nd]),inquery[nd]=0,ans=0;cout<<ans<<"\n";continue;%>
it fir=st.upper_bound(dfn[nd]),sec=fir;
// cout<<"rw/"<<" fir: "<</*ST[0][*fir]*/(fir==st.begin() ? "begin" :(fir==st.end() ? "end" : "find"))<<endl;
fir=(fir==st.begin() ? fir : (fir==st.end() ? st.begin() : --fir) );
sec=(sec==st.begin() ? --st.end() : (sec==st.end() ? --sec : sec) );
// cout<<"nd: "<<nd<<" fir: "<<ST[0][*fir]<<" sec: "<<ST[0][*sec]<<endl;
ll dist=cal(ST[0][*fir],nd)+cal(nd,ST[0][*sec])-cal(ST[0][*fir],ST[0][*sec]);
if(inquery[nd]){
ans-=dist;
st.erase(dfn[nd]);
inquery[nd]=0;
}
else{
ans+=dist;
st.insert(dfn[nd]);
inquery[nd]=1;
}
cout<<ans<<"\n";
}
return 0;
}
求调QAQ