为啥错了
查看原帖
为啥错了
800499
suzhikz（楼主） 2024/12/24 20:58
#include<bits/stdc++.h>
#define ll long long
#define reg register
#define db double
#define il inline
using namespace std;
// Reads the next decimal integer from stdin into x.
// Every non-digit character before the first digit is skipped; if any of them
// is '-', the result is negated. Reading stops at the first non-digit after the
// number, and that character is consumed.
// NOTE(review): never terminates if EOF is hit before any digit.
void read(int &x){
	x=0;
	int sign=1;
	char ch;
	for(ch=getchar();ch<'0'||ch>'9';ch=getchar())
		if(ch=='-')sign=-1;
	for(;ch>='0'&&ch<='9';ch=getchar())
		x=x*10+ch-'0';
	x*=sign;
}
// 64-bit overload of the stdin integer reader.
// Skips everything up to the first digit (a '-' anywhere in that skipped run
// negates the result), then accumulates the digit run; the terminating
// non-digit character is consumed.
// NOTE(review): never terminates if EOF is hit before any digit.
void read(ll &x){
	x=0;
	int sign=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){
		if(ch=='-')sign=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		x=x*10+ch-'0';
		ch=getchar();
	}
	x*=sign;
}
// Array capacity for per-vertex tables (vertices are 1-indexed; presumably n < N).
const int N=2e5+5; 
// n      : number of vertices (vertices are 1..n).
// fa[i]  : parent of vertex i; fa[1] is never assigned, so it stays 0 and marks the root.
// vis[x] : set once vertex x has been popped from the heap and processed.
// cnt0/cnt1[x] : number of 0-valued / 1-valued vertices in the block whose
//                union-find representative is x.
int n,fa[N],vis[N];ll cnt0[N],cnt1[N];
// Running total: each merge adds (#1s in the parent block) * (#0s in the absorbed block).
ll ans;
// Heap entry describing one block of merged vertices.
//   c0 / c1 : number of 0-valued / 1-valued vertices in the block
//   id      : union-find representative vertex of the block
struct node{
	long long c0,c1,id;
	// Ordering for std::priority_queue (a max-heap): a < b when a's 0:1 ratio is
	// smaller than b's, so the block with the largest c0/c1 ratio is on top.
	// Cross-multiplication avoids division and treats c1 == 0 as an infinite ratio.
	// The explicit `return` is essential: falling off the end of a bool function
	// is undefined behaviour and leaves the heap order arbitrary.
	friend bool operator <(const node& a,const node& b){
		return a.c1*b.c0>b.c1*a.c0;
	}
};
// Max-heap of block entries ordered by node::operator<; it may hold older entries
// for a vertex alongside newer ones (main skips them via vis[]).
priority_queue<node>q;
// Union-find parent pointers: p[i]==i means i is the representative of its block.
int p[N];
// Returns the union-find representative of x and compresses the path:
// every vertex visited on the way up is re-pointed directly at the representative.
int find(int x){
	int root=x;
	while(p[root]!=root)root=p[root];
	while(x!=root){
		const int next=p[x];
		p[x]=root;
		x=next;
	}
	return root;
}
// Appends x's block directly after y's block.
// Adds (#1s in y's block) * (#0s in x's block) to ans, folds x's counts into
// y's representative, then links x's representative under y's.
void merge(int x,int y){
	const int child=find(x);
	const int parent=find(y);
	ans+=cnt0[child]*cnt1[parent];
	cnt0[parent]+=cnt0[child];
	cnt1[parent]+=cnt1[child];
	p[child]=parent;
}
// Reads n, the parents of vertices 2..n, then a 0/1 value for each vertex 1..n.
// Greedily merges blocks into their parent's block and prints the accumulated ans.
int main(){
	read(n);
	// fa[1] is never assigned, so it keeps its zero initial value; that marks vertex 1 as the root below.
	for(int i=2;i<=n;i++)read(fa[i]);
	// Every vertex starts as its own block holding a single 0 or 1.
	for(int i=1;i<=n;i++){
		int u;
		read(u);
		p[i]=i;
		if(u==0)cnt0[i]=1;else cnt1[i]=1;
		// Standard brace initialisation; the C99 compound literal `(node){...}`
		// is only a GNU extension in C++.
		q.push(node{cnt0[i],cnt1[i],i});
	}
	// Repeatedly take the block ranked highest by node::operator< and append it
	// directly after the block that currently contains its parent.
	while(!q.empty()){
		int x=q.top().id;q.pop();
		// Each vertex is processed at most once; later heap entries for it are skipped.
		if(vis[x])continue;
		vis[x]=1;
		if(fa[x]!=0){
			int y=find(fa[x]);
			merge(x,y);
			// y's block grew, so re-insert it with its updated counts.
			q.push(node{cnt0[y],cnt1[y],y});
		}
	}
	cout<<ans;
	return 0;
}

2024/12/24 20:58
加载中...