跪求大佬帮忙
查看原帖
跪求大佬帮忙
1442100
bizikang（楼主） 2024/12/24 11:38
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
// Capacity shared by the static arrays in this file (vertex slots and edge slots).
const long long N = 2200000;
// Input scratch: vertex count, then the endpoints and weight of the edge being read.
long long n;
long long u, v, w;
// h[x]: index in E of the most recently added edge leaving x (main sets all to -1 = empty).
long long h[N];
// num: next free slot in E; slot 0 is never used.
long long num = 1;
// One directed adjacency-list record: edge x -> y with weight z.
// `next` is the index of the previously added edge with the same tail
// (lists are terminated by -1, the value main() puts in the head arrays).
struct node {
	long long x;     // tail vertex
	long long y;     // head vertex
	long long z;     // edge weight
	long long next;  // next edge in x's list, or -1
};
// Edge pool for the original tree.
node E[N];
// Pushes the directed edge x -> y (weight z) onto the front of x's list in E.
// Consumes one slot (the current value of num) and advances num.
void add(long long x, long long y, long long z){
	const long long slot = num;
	++num;
	E[slot].x = x;
	E[slot].y = y;
	E[slot].z = z;
	E[slot].next = h[x];  // link to the previous head
	h[x] = slot;          // new edge becomes the head
}
// Per-vertex data filled by dfs():
//   dfn[x]  - DFS entry index (root 1 gets 1), dep[x] - depth (root is 1, vertex 0 is 0)
//   f[x][i] - 2^i-th ancestor of x (0 once the jump passes the root)
//   d[x][i] - minimum edge weight among the 2^i edges above x
// mi[i] = 2^i (filled in main); s[x] != 0 marks x as a key vertex of the current query;
// cnt is the running DFS counter.
// Only levels 0..30 are ever indexed, so 31 columns suffice (40 wasted ~22% of f and d).
long long dfn[N], dep[N], f[N][31], d[N][31], mi[31], s[N], cnt=0;
// Returns the smaller of a and b (b when they are equal).
long long minn(long long a, long long b){
	return (a < b) ? a : b;
}
// Recursive DFS from x, whose parent is fa and whose parent edge has weight w.
// Assigns dfn/dep and fills x's binary-lifting rows f[x][*] and d[x][*]
// before visiting children, so every ancestor's rows are already complete.
void dfs(long long x, long long fa, long long w){
	dfn[x] = ++cnt;
	dep[x] = dep[fa] + 1;
	f[x][0] = fa;
	d[x][0] = w;
	for (long long lv = 1; lv <= 30; ++lv) {
		const long long mid = f[x][lv - 1];  // halfway ancestor
		f[x][lv] = f[mid][lv - 1];
		d[x][lv] = minn(d[mid][lv - 1], d[x][lv - 1]);
	}
	for (long long e = h[x]; e != -1; e = E[e].next) {
		if (E[e].y == fa) continue;  // do not walk back to the parent
		dfs(E[e].y, x, E[e].z);
	}
}
// Exchanges the values of x and y.
void swap(long long &x, long long &y){
	const long long old = x;
	x = y;
	y = old;
}
long long lca(long long x, long long y){
	if(dep[x]<dep[y]) swap(x, y);
	for(long long i=30;i>=0;i--)
		if(dep[x]-mi[i]>=dep[y])
			x=f[x][i];
	for(long long i=30;i>=0;i--)
		if(f[x][i]!=f[y][i]){
			x=f[x][i];
			y=f[y][i];
		}
	if(x!=y) return f[x][0];
	else return x;
}
long long dist(long long x, long long y){
	if(lca(x, y)!=x) swap(x, y);
	long long min1=0x7ffffffff;
	for(long long i=30;i>=0;i--)
		if(dep[y]-mi[i]>=dep[x]){
			if(d[y][i]<min1) min1=d[y][i];
			y=f[y][i];
		}
	return min1;
}
// Query state: m queries, k key vertices, a[] the key list, g[] the virtual-tree vertex list.
long long m, k;
long long a[N], g[N];
// Orders vertices by DFS entry index.
bool cmp(long long lhs, long long rhs){
	return dfn[lhs] < dfn[rhs];
}
// Virtual-tree adjacency list, rebuilt each query (same layout as E/h/num).
node E1[N];
long long h1[N];
long long num1 = 1;
// Pushes the directed virtual-tree edge x -> y (weight z) onto the front of x's list in E1.
void conn(long long x, long long y, long long z){
	const long long slot = num1;
	++num1;
	E1[slot].x = x;
	E1[slot].y = y;
	E1[slot].z = z;
	E1[slot].next = h1[x];  // link to the previous head
	h1[x] = slot;
}
// DP value per virtual-tree vertex (valid after findans visits it).
long long dp[N];
// Post-order DP over the virtual tree, entered at x from parent fa; z is the
// weight stored on the virtual edge fa-x (the path minimum computed in main).
// dp[x] = z if x is a key vertex, otherwise min(z, sum of children's dp).
// Presumably this is the cheapest cut separating x's key descendants from fa -
// confirm against the problem statement.
void findans(long long x, long long fa, long long z){
	dp[x] = 0;
	long long below = 0;
	for (long long e = h1[x]; e != -1; e = E1[e].next) {
		const long long child = E1[e].y;
		if (child != fa) {
			findans(child, x, E1[e].z);
			below += dp[child];
		}
	}
	dp[x] = s[x] ? z : minn(z, below);
}
// Reads the tree and m queries. For each query, builds the virtual tree of the
// k key vertices plus root 1, runs findans() on it, and prints dp[1].
int main(){
//	freopen("in.in", "r", stdin);
	scanf("%lld", &n); mi[0]=1;
	// mi[i] = 2^i, used by lca()/dist() to lift by depth differences.
	for(long long i=1;i<=30;i++) mi[i]=2*mi[i-1];
	memset(h, -1, sizeof(h));
	memset(s,  0, sizeof(s));
	// Sentinel for d[][] (dfs() recomputes every level of each vertex it reaches).
	for(long long i=1;i<=n;i++)
		for(long long j=0;j<=30;j++)
			d[i][j]=0x7fffffff;
	for(long long i=1;i<n;i++){
		scanf("%lld%lld%lld", &u, &v, &w);
		add(u, v, w);
		add(v, u, w);
	}
	dfs(1, 0, 0);
	scanf("%lld", &m);
	memset(h1, -1, sizeof(h1));
	for(long long ii=1;ii<=m;ii++){
		scanf("%lld", &k);
		for(long long i=1;i<=k;i++){
			scanf("%lld", &a[i]);
			s[a[i]] = 1;
		}
		long long gn=0, lc; num1=1;
		// Add the root so the virtual tree is always rooted at vertex 1.
		a[k+1] = 1;
		sort(a+1, a+k+2, cmp);
		// Vertex set: key vertices, the root, and the LCA of each DFS-adjacent pair.
		for(long long i=1;i<=k+1;i++){
			g[++gn] = a[i];
			if(i!=1) g[++gn]=lca(a[i], a[i-1]);
		}
		sort(g+1, g+gn+1, cmp);
		gn = unique(g+1, g+gn+1)-(g+1);
		// Each vertex after the first hangs under the LCA of it and its DFS predecessor.
		for(long long i=1;i<gn;i++){
			lc=lca(g[i], g[i+1]);
			const long long cost=dist(lc, g[i+1]);  // computed once for both directions
			conn(lc, g[i+1], cost);
			conn(g[i+1], lc, cost);
		}
		findans(1, 0, 0x7ffffffff);
		printf("%lld\n", dp[1]);
		// Reset only the adjacency heads touched by this query.
		for(long long i=1;i<gn;i++){
			h1[lca(g[i], g[i+1])]=-1;
			h1[g[i+1]]=-1;
		}
		// BUG FIX: after the sort, a[1..k+1] holds the k key vertices plus root 1
		// (root first, since dfn[1] == 1). Clearing only a[1..k] left the key vertex
		// with the largest dfn flagged in s[]. Later queries then treated it as a key
		// vertex whenever it appeared as an LCA, giving wrong answers.
		for(long long i=1;i<=k+1;i++) s[a[i]] = 0;
	}
	return 0;
}

求大佬帮忙看看到底哪里出了问题，WA 了 5 个测试点。

2024/12/24 11:38
加载中...