求助线段树合并
  • 板块学术版
  • 楼主StarsIntoSea
  • 当前回复0
  • 已保存回复0
  • 发布时间2025/7/30 11:11
  • 上次更新2025/7/30 11:24:23
查看原帖
求助线段树合并
1121518
StarsIntoSea楼主2025/7/30 11:11

原题:CF600E

下面的代码正常插入一个值没有问题,但是在合并后,样例一第 4 个线段树会出奇怪错误,导致向上合并全错了。

如果不合并,每棵线段树都是对的。

具体在 52 行,一旦合并第 4 棵树就会似,往上合并也会错。

询问应该也有点问题,我自己调。

#include <stdio.h>
#define mid ((l+r)>>1)
#define ll long long
const int N=1e5+5;
int n,in[N],h[N],col[N],idx=0;
struct edge{int to,ne;}e[N*2];
struct node{ll res;int tmp;};
void add(int a,int b){
	e[++idx]={b,h[a]};
	h[a]=idx;
}
int rt[N],lc[N*200],rc[N*200],t[N*200],tot=0;
ll ans[N];
void insert(int u,int &p,int l,int r,int x){
	if(!t[p]) p=++tot; t[p]=1;
//	printf("%d %d %d %d\n",u,t[p],l,r);
	if(l==r) return ;
	if(x<=mid) insert(u,lc[p],l,mid,x);
	else insert(u,rc[p],mid+1,r,x);
}
int merge(int x,int y,int l,int r){
	if(!x||!y) return x+y;
	if(l==r) {t[x]+=t[y];return x;}
	lc[x]=merge(lc[x],lc[y],l,mid);
	rc[x]=merge(rc[x],rc[y],mid+1,r);
	t[x]=t[lc[x]]+t[rc[x]];
	return x;
}
node query(int u,int p,int l,int r){
	if(l==r) return (node){(ll)l,t[l]};
	if(!t[rc[p]]) return query(u,t[lc[p]],l,mid);
	else if(!t[lc[p]]) return query(u,t[rc[p]],mid+1,r);
	else{
		node r1=query(u,t[lc[p]],l,mid);
		node r2=query(u,t[rc[p]],mid+1,r);
		if(r1.tmp==r2.tmp) return (node){r1.res+r2.res,r1.tmp};
		else if(r1.tmp>r2.tmp) return r1;
		else return r2;
	}
}
void dfs(int u,int fa){
	insert(u,rt[u],1,n,col[u]);
	if(in[u]<=1&&u!=1){
		ans[u]=col[u];
		return ;
	}
	for(int i=h[u];i;i=e[i].ne){
		int v=e[i].to;
		if(v==fa) continue;
//		printf("%d %d\n",u,v);
		dfs(v,u);
		rt[u]=merge(rt[u],rt[v],1,n);
	}
	ans[u]=query(u,rt[u],1,n).res;
}

//void init(int u,int p,int l,int r){
//	if(!t[p]) return ;
//	printf("%d %d %d %d\n",u,t[p],l,r);
//	if(l==r) return ;
//	init(u,lc[p],l,mid);
//	init(u,rc[p],mid+1,r);
//}
int main(){
	scanf("%d",&n);
	for(int i=1;i<=n;++i) scanf("%d",&col[i]);
	for(int i=1;i<n;++i){
		int u,v; scanf("%d%d",&u,&v);
		add(u,v); add(v,u);
		in[u]++; in[v]++;
	}
	dfs(1,0);
//	for(int i=1;i<=n;++i) init(i,rt[i],1,n);
//	for(int i=1;i<=n;++i) printf("%d ",ans[i]);
//	printf("\n");
}
2025/7/30 11:11
加载中...