0pts求助
查看原帖
0pts求助
800499
suzhikz楼主2024/12/6 15:50

summ维护的是概率和

#include<bits/stdc++.h>
#define ll long long
#define reg register
#define db double
#define il inline
#define int long long
using namespace std;
// Reads one integer from stdin into x.
// Any non-digit characters before the number are skipped; a '-' met while
// skipping makes the result negative. The character is kept as the int
// returned by getchar so EOF can be detected: on EOF the skip loop stops
// instead of spinning forever, and x ends up 0 if no digit was read.
void read(ll &x){
	x=0;
	int f=1;
	int c=getchar();
	while(c!=EOF&&(c<'0'||c>'9')){
		if(c=='-')f=-1;
		c=getchar();
	}
	while(c>='0'&&c<='9'){
		x=x*10+(c-'0');
		c=getchar();
	}
	x*=f;
}
const int N=3e5+5;        // capacity of the per-tree-node arrays (indexed 1..n)
const int mod=998244353;  // modulus for all probability arithmetic (prime; inverses via Fermat)
int n;                    // number of tree nodes read from input
int rt[N];                // rt[v]: root (pool index) of v's value-distribution segment tree
vector<int> g[N];         // g[v]: children of v, filled from the parent list in main
// Shared node pool for all dynamic segment trees; index 0 acts as the null node.
int ls[N<<5];             // left child of a pool node
int rs[N<<5];             // right child of a pool node
ll summ[N<<5];            // summ[x]: probability mass of the values in x's range
ll tot;                   // number of pool nodes handed out so far
// Returns x^y mod `mod` by binary exponentiation; y is expected to be >= 0.
// The base is first normalized into [0, mod) so that x*x stays below 2^63
// even when the caller passes a large or negative base.
ll qpow(ll x,int y){
	ll re=1;
	x%=mod;
	if(x<0)x+=mod;
	while(y>0){
		if(y&1)re=re*x%mod;
		x=x*x%mod;
		y>>=1;
	}
	return re;
}
ll p[N];       // p[v] as read: a leaf's value, or for an internal node the number dfs divides by 10000 to get a probability
ll lisan[N];   // leaf values, sorted in main and used 1-based for coordinate compression
ll tp;         // number of entries stored in lisan
void push_up(int x){
	summ[x]=(summ[ls[x]]+summ[rs[x]])%mod;
}
void update(int x,int l,int r,int pos,int w){
	if(l==r){
		summ[x]+=w;
		return;
	}
	int mid=(l+r)/2;
	if(pos<=mid){
		if(ls[x]==0)ls[x]=++tot;
		update(ls[x],l,mid,pos,w);
	}else{
		if(rs[x]==0)rs[x]=++tot;
		update(rs[x],mid+1,r,pos,w);
	}
}
void push_down(int x); // defined further down; needed to flush lazy tags here
// Returns (mod `mod`) the total probability mass at positions [ql, qr]
// inside the subtree x that covers [l, r]; a null subtree (x == 0) gives 0.
// Pending lazy multipliers on x are pushed to its children before
// descending, otherwise the children's summ values would be stale.
ll query(int x,int l,int r,int ql,int qr){
	if(x==0)return 0;
	if(ql<=l&&qr>=r)return summ[x]%mod;
	push_down(x);
	int mid=(l+r)/2;
	ll re=0;
	if(ql<=mid)re=query(ls[x],l,mid,ql,qr);
	if(qr>mid)re+=query(rs[x],mid+1,r,ql,qr);
	return re%mod;
}
int find(int x){
	int l=1,r=tp,mid,ans;
	while(l<=r){
		mid=(l+r)/2;
		if(lisan[mid]<=x){
			ans=x;
			l=mid+1;
		}else{
			r=mid-1;
		}
	}
	return ans;
}
// tag[x]: pending multiplier that still has to be applied to x's children
// (summ[x] itself is already scaled). 0 means "no pending multiplier".
// Composing two nonzero residues modulo the prime `mod` never yields 0,
// so composition below cannot accidentally produce the sentinel.
int tag[N<<5];
// Applies x's pending multiplier to each existing child: the child's summ is
// scaled and the multiplier is composed (not overwritten) into the child's
// own pending tag. The null node 0 is never touched.
void push_down(int x){
	if(tag[x]==0)return;
	const int t=tag[x];
	const int kids[2]={ls[x],rs[x]};
	for(int k=0;k<2;k++){
		const int y=kids[k];
		if(y==0)continue;
		summ[y]=summ[y]*t%mod;
		tag[y]=(tag[y]==0)?t:tag[y]*t%mod;
	}
	tag[x]=0;
}
int merge(int x1,int x2,int l,int r,int pre1,int pre2,int p){
//	cout<<x1<<' '<<x2<<' '<<l<<' '<<r<<endl;
//cout<<p<<endl;
	ll _p=(1-p+mod)%mod;
	pre1%=mod;pre2%=mod;
	push_down(x1);push_down(x2); 
	if(x1==0||x2==0){
//		cout<<
		if(x1==0&&x2==0)return 0;
		if(x1==0){
//			cout<<12;
			summ[x2]=summ[x2]*p%mod*pre1%mod+summ[x2]*_p%mod*(1-pre1)%mod;
			tag[x2]=p*pre1%mod+(_p)*(1-pre1)%mod;
			return x2;
		}else{
			summ[x1]=summ[x1]*p%mod*pre2%mod+summ[x1]*(_p)%mod*(1-pre2)%mod;
			tag[x1]=p*pre2%mod+(_p)*(1-pre2)%mod;
			return x1;
		}
	}
	int noww=++tot;
	int mid=(l+r)/2;
	ls[noww]=merge(ls[x1],ls[x2],l,mid,pre1,pre2,p);
	rs[noww]=merge(rs[x1],rs[x2],mid+1,r,pre1+summ[ls[x1]],pre2+summ[ls[x2]],p);
	push_up(noww);
	return noww;
}
void dfs(int x){
//	cout<<x<<endl;
	for(int i:g[x]){
		dfs(i);
	}
//	cout<<x<<endl;
	if(g[x].size()==1)rt[x]=rt[g[x][0]];
	else if(g[x].size()==0){
		rt[x]=++tot;
		update(rt[x],1,tp,find(p[x]),1);
	}else{
//		cout<<g[x][0]<<' '<<g[x][1]<<endl;
		rt[x]=merge(rt[g[x][0]],rt[g[x][1]],1,tp,0,0,p[x]*qpow(10000,mod-2)%mod);
	}
}
signed main(){
	read(n);
	for(int u,i=1;i<=n;i++){
		read(u);
		if(u!=0){
			g[u].push_back(i);
		}
	}
	for(int i=1;i<=n;i++){
		read(p[i]);
		if(g[i].size()==0){
			lisan[++tp]=p[i];
		}
	}
	sort(lisan+1,lisan+1+tp);
	dfs(1);
	ll ans=0;
//	cout<<summ[4]<<endl;
	for(int i=1;i<=tp;i++){
//		cout<<query(rt[1],1,tp,i,i)<<endl; 
		ans=ans+i*query(rt[1],1,tp,i,i)%mod*query(rt[1],1,tp,i,i)%mod*lisan[i]%mod;
	}
	cout<<ans;
	return 0;
}

2024/12/6 15:50
加载中...