0分求助
查看原帖
0分求助
800499
suzhikz楼主2024/12/6 18:38
#include<bits/stdc++.h>
#define ll long long
#define reg register
#define db double
#define il inline
#define int long long
using namespace std;
// Fast reader for one signed decimal integer from stdin.
// Skips every non-digit character before the number; any '-' seen while skipping
// makes the result negative. Reading stops at (and consumes) the first non-digit
// after the digits. If EOF is reached before any digit, x is left as 0 instead of
// looping forever (getchar()'s result is kept in an int so EOF stays distinguishable).
void read(long long &x){
	x=0;
	int sign=1;
	int c=getchar();
	while(c!=EOF&&(c<'0'||c>'9')){
		if(c=='-')sign=-1;
		c=getchar();
	}
	while(c>='0'&&c<='9'){
		x=x*10+(c-'0');
		c=getchar();
	}
	x*=sign;
}
const int N=3e5+5;          // upper bound on node count, with slack
const int mod=998244353;    // all probabilities/sums are kept modulo this prime
int n;                      // number of tree nodes (read in main)
int rt[N];                  // rt[v]: pool index of the segment-tree root built for node v
vector<int>g[N];            // g[v]: children of node v
// Node pool for all dynamic segment trees: 32*N slots, index 0 is the null child.
int ls[N<<5];
int rs[N<<5];
ll summ[N<<5];              // mass stored in each pool node's subtree, reduced mod `mod`
ll tot;                     // last allocated pool index
// Square-and-multiply exponentiation: returns x^y reduced modulo `mod`.
// Used with y = mod-2 to obtain a modular inverse (Fermat).
ll qpow(ll x,int y){
	ll acc=1;
	for(;y;y/=2){
		if(y&1)acc=acc*x%mod;
		x=x*x%mod;
	}
	return acc;
}
ll p[N];        // p[v]: input value of node v (leaf weight; for inner nodes it is scaled by 1/10000 before use in merge)
ll lisan[N];    // sorted leaf weights, 1-based (coordinate compression)
ll tp;          // number of leaves, i.e. used length of lisan
int tag[N<<5];  // pending multiplicative lazy tag per pool node; 0 means "no pending tag"
void push_up(int x){
	summ[x]=(summ[ls[x]]+summ[rs[x]])%mod;
}
void push_down(int x){
	if(tag[x]==0)return;
//	cout<<x<<' '<<ls[x]<<endl<<rs[x]<<endl; 
	summ[ls[x]]=summ[ls[x]]*tag[x]%mod;tag[ls[x]]=tag[x];
	summ[rs[x]]=summ[rs[x]]*tag[x]%mod;tag[rs[x]]=tag[x];
	tag[x]=1;
}
// Add one unit of mass at compressed position pos to the tree rooted at x,
// which covers [l,r]. Missing nodes on the path are allocated from the pool.
// The parameter w is currently unused: the leaf always receives +1.
void update(int x,int l,int r,int pos,int w){
	if(l==r){
		tag[x]=1;
		summ[x]+=1;
		return;
	}
	push_down(x);
	const int mid=(l+r)/2;
	const bool goLeft=(pos<=mid);
	int &child=goLeft?ls[x]:rs[x];
	if(!child)child=++tot;
	if(goLeft)update(child,l,mid,pos,w);
	else update(child,mid+1,r,pos,w);
	push_up(x);
}
// Sum of the mass stored at positions [ql,qr] within node x's range [l,r], mod `mod`.
// Pending tags are pushed down along every partially covered node that is visited.
ll query(int x,int l,int r,int ql,int qr){
	if(ql<=l&&r<=qr)return summ[x];
	push_down(x);
	const int mid=(l+r)/2;
	ll total=0;
	if(ql<=mid&&ls[x])total=query(ls[x],l,mid,ql,qr);
	if(qr>mid&&rs[x])total+=query(rs[x],mid+1,r,ql,qr);
	return total%mod;
}
// Map a leaf weight to its 1-based rank in the sorted array lisan[1..tp]:
// returns the largest index i with lisan[i] <= x, or 1 if there is none.
// It is called with weights that are themselves stored in lisan, so this is an
// exact lookup. The original recorded the weight `x` instead of the index `mid`,
// which put each leaf at position = weight instead of position = rank.
int find(int x){
	int lo=1,hi=tp,ans=1;
	while(lo<=hi){
		const int mid=lo+(hi-lo)/2;
		if(lisan[mid]<=x){
			ans=mid;
			lo=mid+1;
		}else{
			hi=mid-1;
		}
	}
	return ans;
}
bool c=false;  // debug flag: set when dfs reaches node 1; only read by commented-out tracing in merge
int merge(int x1,int x2,int l,int r,int pre1,int pre2,int p){
//	if(c)cout<<x1<<' '<<x2<<' '<<l<<' '<<r<<' '<<pre1<<' '<<pre2<<' '<<p<<' '<<summ[8]<<endl;
	ll _p=(1-p+mod)%mod;
	pre1%=mod;pre2%=mod;
	push_down(x1);push_down(x2); 
//	cout<<summ[8]<<endl;
	if(x1==0||x2==0){
		if(x1==0){
			tag[x2]=p*pre1%mod+(_p)*(1-pre1+mod)%mod;tag[x2]%=mod;
			summ[x2]=summ[x2]*tag[x2]%mod;
			return x2;
		}else{
			tag[x1]=p*pre2%mod+(_p)*(1-pre2+mod)%mod;tag[x1]%=mod;
//			cout<<summ[x1]<<' '<<x1<<endl;
			summ[x1]=summ[x1]*tag[x1]%mod;
			return x1;
		}
		return 0;
	}
	int noww=++tot;
	int mid=(l+r)/2;
	rs[noww]=merge(rs[x1],rs[x2],mid+1,r,pre1+summ[ls[x1]],pre2+summ[ls[x2]],p);
	ls[noww]=merge(ls[x1],ls[x2],l,mid,pre1,pre2,p);
	push_up(noww);
	return noww;
}
// Debug helper: walk the tree rooted at x (left subtree first) and print every
// childless node with a non-zero stored sum as "<sum> <index>work".
// Pending lazy tags are not pushed, so printed sums may not include them.
void work(int x){
	const bool childless=(ls[x]==0&&rs[x]==0);
	if(childless){
		if(summ[x])cout<<summ[x]<<' '<<x<<"work"<<endl;
		return;
	}
	if(ls[x])work(ls[x]);
	if(rs[x])work(rs[x]);
}
void dfs(int x){
	for(int i:g[x]){
		dfs(i);
	}
//	if(x==1)work(rt[2]);
	if(x==1)c=1;
	if(g[x].size()==1)rt[x]=rt[g[x][0]];
	else if(g[x].size()==0){
		rt[x]=++tot;
		update(rt[x],1,tp,find(p[x]),1);
	}else{
		rt[x]=merge(rt[g[x][0]],rt[g[x][1]],1,tp,0,0,p[x]*qpow(10000,mod-2)%mod);
	}
//	cout<<endl;
}
signed main(){
	read(n);
	for(int u,i=1;i<=n;i++){
		read(u);
		if(u!=0){
			g[u].push_back(i);
		}
	}
	for(int i=1;i<=n;i++){
		read(p[i]);
		if(g[i].size()==0){
			lisan[++tp]=p[i];
		}
	}
	sort(lisan+1,lisan+1+tp);
	dfs(1);
//	work(rt[1]);
	ll ans=0;
	for(int i=1;i<=tp;i++){
//		cnt=(cnt+query(rt[1],1,tp,i,i))%mod;
		ans=ans+i*query(rt[1],1,tp,i,i)%mod*query(rt[1],1,tp,i,i)%mod*lisan[i]%mod;ans%=mod;
	}
//	if(cnt!=1)cout<<"%";
	cout<<ans;
	return 0;
}
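/*
Hand-made test input (line 1: n; line 2: parent of each node, 0 = root;
line 3: leaf weight, or probability*10000 for inner nodes).
Note: the leaf weights here (1,2,3,4) equal their ranks, so this input does not
exercise the weight-to-rank lookup.
7
0 1 2 2 1 5 5
5000 5000 1 2 500 3 4
*/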

2024/12/6 18:38
加载中...