求助虚树 0pts WA + MLE
查看原帖
求助虚树 0pts WA + MLE
544860
linyihdfj楼主2022/2/13 12:13

求助大佬看看是不是哪里有细节问题

#include<bits/stdc++.h>
using namespace std;
// Capacity of every per-vertex array (vertex ids and per-query lists).
constexpr long long MAXN = 300000;
// Edge-capacity parameter.
constexpr long long MAXM = 500000;
// Sentinel "infinite" cost; main() assigns it to minv of the root.
constexpr long long INF = 1000000000000000000LL;
// Adjacency-list edge record (chained-forward-star style).
//   nxt - index of the next edge leaving the same vertex (0 ends the list)
//   to  - endpoint vertex of this edge
//   val - edge weight (0 for virtual-tree edges)
struct edge{
	long long nxt,to,val;
	edge(long long a1=0,long long a2=0,long long a3=0)
		: nxt(a1), to(a2), val(a3) {}
};
// Global state. Vertices are 1-based; index 0 acts as a sentinel
// (fa[][] is 0 past the root, st[0] and dep[0] stay 0).
long long n,q,top,cnt=1,cnt2=1,dfs_cnt;
long long w[MAXN],dfn[MAXN],st[MAXN],head[MAXN],head2[MAXN],dep[MAXN],minv[MAXN];
// Binary-lifting table. Vertex ids are < MAXN, so int is enough; this halves
// the largest array (was long long[MAXN][25], about 60 MB).
int fa[MAXN][25];
// The original tree stores 2(n-1) directed edges and one virtual tree stores
// fewer than 2m edges per query (cnt2 is reset each query). Both counters
// start at 1, so 2*MAXN slots suffice. The previous 3*MAXM (1.5M) slots per
// list were a large part of the memory use.
edge e[2 * MAXN],e2[2 * MAXN];
bool query[MAXN];
// Prepend a directed edge from -> to with weight val to from's adjacency list
// in the original tree (e / head / cnt).
void add_edge(long long from,long long to,long long val){
	cnt += 1;
	e[cnt].nxt = head[from];
	e[cnt].to = to;
	e[cnt].val = val;
	head[from] = cnt;
}
// Pre-order DFS of the original tree from `now`. For every newly reached
// vertex v with parent u it assigns:
//   dfn[v]   - entry order (++dfs_cnt)
//   dep[v]   - dep[u] + 1
//   minv[v]  - min(minv[u], weight of edge u-v)
//   fa[v][0] - u
// An explicit stack replaces recursion: a path-shaped tree would otherwise
// recurse about n levels deep and can overflow the call stack. Edges are
// scanned in the same head -> nxt order as the recursive version, so dfn is
// identical.
void dfs1(long long now){
	// Each frame holds (vertex, index of the next edge to scan from it).
	std::vector<std::pair<long long,long long>> frames;
	dfn[now] = ++dfs_cnt;
	frames.emplace_back(now, head[now]);
	while(!frames.empty()){
		long long u = frames.back().first;
		long long i = frames.back().second;
		if(!i){
			frames.pop_back();
			continue;
		}
		// Advance this frame before a push can invalidate references into it.
		frames.back().second = e[i].nxt;
		long long to = e[i].to;
		if(dfn[to])	continue;
		dep[to] = dep[u] + 1;
		minv[to] = std::min(minv[u], e[i].val);
		fa[to][0] = u;
		dfn[to] = ++dfs_cnt;
		frames.emplace_back(to, head[to]);
	}
}
// Fill the binary-lifting table for vertices 1..n, levels 1..20:
// fa[v][k] = fa[ fa[v][k-1] ][k-1]. Jumps past the root yield 0 because
// fa[root][0] and fa[0][*] are 0. Each level is finished for every vertex
// before the next level is built.
void chuli_lca(){
	for(long long k=1; k<=20; k++){
		const long long prev = k - 1;
		for(long long v=1; v<=n; v++){
			const long long mid = fa[v][prev];
			fa[v][k] = fa[mid][prev];
		}
	}
}
// Lowest common ancestor of x and y by binary lifting over fa[][]; requires
// chuli_lca() to have filled levels 1..20.
// y (the deeper vertex) is lifted by the exact depth difference. The old
// test `dep[fa[y][i]] >= dep[x]` compared against the sentinel vertex 0
// (fa[][] is 0 past the root). When the root has depth 0 that test is true
// for x == root, so y climbed onto vertex 0 and LCA(1, v) returned 0.
long long get_lca(long long x,long long y){
	if(dep[x] > dep[y])
		std::swap(x,y);
	// Lift y to x's depth using the binary representation of the gap.
	long long diff = dep[y] - dep[x];
	for(long long i=20; i>=0; i--){
		if((diff >> i) & 1)
			y = fa[y][i];
	}
	if(x == y)
		return y;
	// Lift both while their 2^i-th ancestors differ; they end as children of the LCA.
	for(long long i=20; i>=0; i--){
		if(fa[y][i] != fa[x][i]){
			y = fa[y][i];
			x = fa[x][i];
		}
	}
	return fa[y][0];
}
// Post-order walk of the virtual tree (e2/head2) below `now`.
// Returns minv[now] if `now` is a query vertex. Otherwise it returns the
// smaller of minv[now] and the sum of the children's results. minv[v] is the
// minimum edge weight on the tree path from vertex 1 to v, as set by dfs1.
// On the way out it resets head2[now] and query[now] so the next query
// starts clean.
long long dfs2(long long now){
	long long childTotal = 0;
	for(long long id = head2[now]; id != 0; id = e2[id].nxt)
		childTotal += dfs2(e2[id].to);
	const long long best = query[now] ? minv[now] : min(minv[now], childTotal);
	head2[now] = 0;
	query[now] = false;
	return best;
}
// Prepend a weightless virtual-tree edge from -> to to from's list in e2/head2.
void add2(long long from,long long to){
	cnt2 += 1;
	e2[cnt2].nxt = head2[from];
	e2[cnt2].to = to;
	e2[cnt2].val = 0;
	head2[from] = cnt2;
// Debug: printf("virtual tree: %lld - %lld\n", from, to);
}
// Strict ordering of vertices by DFS entry time (dfn); used to sort the query
// vertices before building the virtual tree.
bool cmp(long long l,long long r){
	const long long left = dfn[l];
	const long long right = dfn[r];
	return left < right;
}
int main(){
//	freopen("in.txt","r",stdin);
//	freopen("out.txt","w",stdout);
	cin>>n;
	for(long long i=1; i<n; i++){
		long long from,to,val;
		cin>>from>>to>>val;
		add_edge(from,to,val);
		add_edge(to,from,val);
	}
	minv[1] = INF;
	dfs1(1);
	cin>>q;
	while(q--){
//		cout<<endl;
		long long m;
		cin>>m;
		for(long long i=1; i<=m; i++){
			cin>>w[i];
			query[w[i]]=true;
		}
		sort(w+1,w+m+1,cmp);
		top = 0;
		st[++top] = w[1];
		for(long long i=2; i<=m; i++){
			long long now = w[i];
			long long lca = get_lca(now,st[top]);
			while(1){
				if(dep[lca] >= dep[st[top-1]]){
					if(lca != st[top]){
						add2(lca,st[top]);
						if(lca != st[top-1]){
							st[top] = lca;
						}
						else{
							top--;
						}
					}
					break;
				}
				else{
					add2(st[top-1],st[top]);
					top--;
				}
			}
			st[++top] = now;
		}
		while(--top){
			add2(st[top],st[top+1]);
		}
		printf("%lld\n",dfs2(st[1]));
		cnt2 = 1;
	}
	return 0;
} 
2022/2/13 12:13
加载中...