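// 第2,第8个点WA(按照真实数据大小排列应该是第7和第10),答案偏大。
// Note: test cases 2 and 8 got Wrong Answer (cases 7 and 10 when ordered by actual input size); the printed answer was too large.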
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=5e5+10;
int n,m; // number of cities, number of armies
// Adjacency list ("chain forward star"): head[x] = first edge out of x,
// nxt[e] = next edge in the same list, ver[e] = target city, edge[e] = road length.
// add() is called once per direction for every road, so edge ids reach 2*(n-1):
// every per-edge array therefore needs N<<1 slots, edge[] included.
int tot,head[N],ver[N<<1],nxt[N<<1],edge[N<<1];
// fa[x][k]  = the 2^k-th ancestor of x (0 above the root),
// len[x][k] = total road length of that 2^k-step climb.
// NOTE(review): with `#define int long long`, fa and len take about 88 MB each at this N;
// shrink N to the problem's real bound if the judge's memory limit requires it.
// dep[] and army[] are not referenced by any code shown here.
int dep[N],fa[N][22],len[N][22];
int army[N];
// One army: pl = the city it is in, lft = the time it has left.
// a[] holds the starting cities read in main(); b[] is scratch storage of the same type for check().
struct Move{
    int pl,lft;
}a[N],b[N];
// Reads the next integer from standard input.
// Any non-digit characters before the number are skipped; if a '-' appears among
// them, the result is negated. Returns 0 if end of input is reached before any digit.
int read(){
    int x=0,f=1;
    // Keep the getchar() result as an int: a char cannot reliably hold EOF, and
    // isdigit() is only defined for EOF or values representable as unsigned char.
    int ch=getchar();
    while(ch!=EOF&&!isdigit(ch)){
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(isdigit(ch)){ // isdigit(EOF) is false, so this also stops at end of input
        x=x*10+ch-'0';
        ch=getchar();
    }
    return x*f;
}
// Appends the directed road x -> y of length z to x's adjacency list.
// The new edge becomes the head of the list, so edges are visited newest-first.
void add(int x,int y,int z){
    ++tot;
    ver[tot]=y;
    edge[tot]=z;
    nxt[tot]=head[x];
    head[x]=tot;
}
// Fills the binary-lifting tables of x, then recurses into its children.
// Expects fa[x][0] / len[x][0] to already hold x's parent and the road length to it
// (both are still 0 for the root, whose ancestors therefore all read as 0).
// Afterwards fa[x][k] is the 2^k-th ancestor and len[x][k] the length of that climb.
void dfs(int x){
    for(int k=1;k<=20;k++){
        int half=fa[x][k-1]; // ancestor 2^(k-1) levels up
        fa[x][k]=fa[half][k-1];
        len[x][k]=len[x][k-1]+len[half][k-1];
    }
    for(int e=head[x];e;e=nxt[e]){
        int child=ver[e];
        if(child!=fa[x][0]){
            fa[child][0]=x;
            len[child][0]=edge[e];
            dfs(child);
        }
    }
}
// stay[x] counts armies that check() leaves at x, plus one when mark() finds every
// child of x non-zero; check() treats a capital-adjacent city with stay == 0 as
// still needing an army.
int stay[N];
// Post-order pass over the subtree of x (parent = the node x was reached from).
// A node with at least one child gets stay[] incremented when every child already
// has a non-zero stay[], i.e. all of its branches are blocked below it.
// All children are always visited, even after an unblocked one is found.
void mark(int x,int parent){
    int children=0,blocked=0;
    for(int e=head[x];e;e=nxt[e]){
        int child=ver[e];
        if(child!=parent){
            mark(child,x);
            ++children;
            if(stay[child])++blocked;
        }
    }
    if(children>0&&blocked==children)stay[x]++;
}
// Comparator: armies with more remaining time come first.
bool cmp1(Move x,Move y){
    return y.lft<x.lft;
}
// Comparator: armies ordered by city id ascending; within one city,
// the army with more remaining time comes first.
bool cmp2(Move x,Move y){
    if(x.pl!=y.pl)return x.pl<y.pl;
    return y.lft<x.lft;
}
// Scratch arrays for check(): need[] collects the road lengths of capital-adjacent
// cities that still need an army, use[] the remaining times of armies at the capital.
int need[N];
int use[N];
// Comparator for sorting plain values in descending order.
bool cmp3(int x,int y){
    return y<x;
}
bool check(int maxtime){
memcpy(b,a,sizeof b);
memset(stay,0,sizeof stay);
memset(need,0,sizeof need);
memset(use,0,sizeof use);
for(int i=1;i<=m;i++){
b[i].lft=maxtime;
for(int j=20;j>=0;j--){
if(fa[b[i].pl][j]&&fa[b[i].pl][j]!=1&&len[b[i].pl][j]<=b[i].lft){
b[i].lft-=len[b[i].pl][j];
b[i].pl=fa[b[i].pl][j];
}
}
stay[b[i].pl]++;
}
mark(1,0);
sort(b+1,b+m+1,cmp1);
for(int i=1;i<=m;i++){
if(stay[b[i].pl]>1&&b[i].lft>=len[b[i].pl][0]){
stay[b[i].pl]--;
b[i].lft-=len[b[i].pl][0];
b[i].pl=fa[b[i].pl][0];
}
}
sort(b+1,b+m+1,cmp2);
int cnt1=0,cnt2=0;
for(int i=head[1];i;i=nxt[i]){
int y=ver[i];
if(!stay[y])need[++cnt1]=len[y][0];
}
sort(need+1,need+cnt1+1,cmp3);
for(int i=1;i<=m;i++)if(b[i].pl==1)use[++cnt2]=b[i].lft;
for(int i=1,j=1;i<=cnt1;i++,j++)if(use[j]<need[i])return false;
return true;
}
// Reads the tree (n cities, n-1 weighted roads, rooted at the capital 1) and the m army
// positions from standard input, then binary-searches the smallest time limit for
// which check() succeeds. Prints -1 when no time limit is enough.
signed main(){
    n=read();
    int total=0; // sum of all road lengths
    for(int i=1;i<n;i++){
        int x=read(),y=read(),z=read();
        add(x,y,z);
        add(y,x,z);
        total+=z;
    }
    dfs(1);
    m=read();
    for(int i=1;i<=m;i++)a[i].pl=read();
    // Within `total` every army reaches the capital, and no road out of the capital is
    // longer than `total`. So once the limit is 2*total, more time changes nothing;
    // failing there means the task is impossible.
    int l=0,r=2*total;
    if(!check(r)){
        puts("-1");
        return 0;
    }
    while(l<r){
        int mid=l+(r-l)/2;
        if(check(mid))r=mid;
        else l=mid+1;
    }
    printf("%lld\n",r);
    return 0;
}