MnZn悬关test2第2个输出答案大于正确答案,对拍不出来
查看原帖
MnZn悬关test2第2个输出答案大于正确答案,对拍不出来
575363
revolutionary_oier楼主2024/11/19 16:06
#include<bits/stdc++.h>
#define int long long 
using namespace std;

const int maxn=1e5+10;
int T,n,cnt,tt;
int head[maxn],a[maxn],ti[maxn];
int c[2][2],d[2][2];
int f[maxn][2];
struct edge{
	int v,nxt;
}e[maxn<<1];
struct vertex{
	int mx,p1,se,p2;
}t[5][maxn];																								
inline void add(int u,int v){
	e[++cnt].v=v;
	e[cnt].nxt=head[u];
	head[u]=cnt;
}
inline void ipt(){
	scanf("%lld",&n);
	cnt=tt=0;
	for(int i=1;i<=n;i++)head[i]=0;
	for(int i=1;i<=n;i++){
		for(int j=0;j<=2;j++)t[j][i].mx=t[j][i].se=0,t[j][i].p1=t[j][i].p2=-1;
		for(int j=0;j<=1;j++)f[i][j]=0;
	}
	for(int i=1;i<=n;i++)scanf("%lld",&a[i]);
	for(int i=1;i<=n;i++)scanf("%lld",&ti[i]);
	for(int i=1;i<n;i++){
		int u,v;
		scanf("%lld%lld",&u,&v);
		add(u,v);
		add(v,u);
	}
}
inline void merge(int op,int u,int v){
	if(op==1){
		if(f[v][2]-f[v][0]+a[v]>t[op][u].mx){
			t[op][u].se=t[op][u].mx;
			t[op][u].p2=t[op][u].p1;
			t[op][u].mx=f[v][2]-f[v][0]+a[v];
			t[op][u].p1=v;
		}	
		else if(f[v][2]-f[v][0]+a[v]>t[op][u].se){
			t[op][u].se=f[v][2]-f[v][0]+a[v];
			t[op][u].p2=v;
		}
	}
	else if(op==2){
		if(f[v][1]-f[v][0]>t[op][u].mx){
			t[op][u].se=t[op][u].mx;
			t[op][u].p2=t[op][u].p1;
			t[op][u].mx=f[v][1]-f[v][0];
			t[op][u].p1=v;
		}	
		else if(f[v][1]-f[v][0]>t[op][u].se){
			t[op][u].se=f[v][1]-f[v][0];
			t[op][u].p2=v;
		}
	}
	else if(op==3){
		if(f[v][1]-f[v][0]>t[op][u].mx){
			t[op][u].se=t[op][u].mx;
			t[op][u].p2=t[op][u].p1;
			t[op][u].mx=f[v][1]-f[v][0];
			t[op][u].p1=v;
		}	
		else if(f[v][1]-f[v][0]>t[op][u].se){
			t[op][u].se=f[v][1]-f[v][0];
			t[op][u].p2=v;
		}
	}
}
inline void dp(int u,int fa){
	int x=0,y=0;
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].v;
		if(v==fa)continue;
		dp(v,u);
		x+=f[v][1];
		y+=f[v][0];
		merge(1,u,v);
		if(ti[v]==3)merge(2,u,v);
		merge(3,u,v);
	}
	f[u][2]=y;
	f[u][0]=y;
	f[u][1]=y+a[u];
	int p=t[3][u].mx;
//	if(u==1)printf("= %lld %lld\n",y,p);
	f[u][1]=max(f[u][1],y+a[u]+p);
	f[u][0]=max(f[u][0],y+p);
	c[0][0]=t[2][u].mx;
	c[0][1]=t[2][u].p1;
	c[1][0]=t[2][u].se;
	c[1][1]=t[2][u].p2;
	d[0][0]=t[1][u].mx;
	d[0][1]=t[1][u].p1;
	d[1][0]=t[1][u].se;
	d[1][1]=t[1][u].p2;
	for(int i=0;i<=1;i++){
		if(c[i][1]==-1)continue;
		for(int j=0;j<=1;j++){
			if(d[j][1]==-1)continue;
			if(d[j][1]==c[i][1])continue;
//			printf("=> %lld %lld\n",c[i][1],d[j][1]);
			f[u][1]=max(f[u][1],y+a[u]+c[i][0]+d[j][0]);
			f[u][0]=max(f[u][0],y+c[i][0]+d[j][0]);
		}
	}
}
signed main(){
	scanf("%lld",&T);
	while(T--){
		ipt();
		dp(1,1);
//		printf("%lld\n",f[2][1]);
//		int x=f[1][1];
//		while(x){
//			printf("%lld",x%10);
//			x/=10;
//		}
//		printf("\n");
//		for(int i=1;i<=n;i++)printf("%lld %lld\n",f[i][0],f[i][1]);
		printf("%lld\n",f[1][1]);
	}
	return 0;
}
/*
1
5
1 10 100 1000 10000
1 2 1 1 1
1 2
1 3
2 4
2 5
*/
2024/11/19 16:06
加载中...