如何让我的平衡树跑得飞快
  • 板块学术版
  • 楼主wanjiabao
  • 当前回复9
  • 已保存回复9
  • 发布时间2024/12/21 10:43
  • 上次更新2024/12/21 14:12:31
查看原帖
如何让我的平衡树跑得飞快
939957
wanjiabao楼主2024/12/21 10:43

如题，肯定是蒟蒻码风太诡异了，导致 treap 和 splay 在普通平衡树那道题（P3369）上只能跑 300+ ms，而班里一堆大佬都跑了 200 ms。

更让蒟蒻绷不住的是，在加强版上我的替罪羊树跑得比 treap 还快：treap 用时 12.76 s，替罪羊树用时 10.10 s。

treap:

#include<bits/stdc++.h>
using namespace std;
// Treap node.
//   ch[0]/ch[1] : children (the shared NIL sentinel when absent)
//   key         : stored value; equal values share one node
//   cnt         : multiplicity of key in this node
//   siz         : total multiplicity of the whole subtree
//   fix         : random heap priority; a child with a smaller fix is rotated up
// key is an int: every value stored comes from newnode(int) or INT_MAX, so the
// former 64-bit key only padded the node (40 -> 32 bytes on typical 64-bit
// targets) and hurt cache usage in this pointer-chasing structure.
struct node{
	node *ch[2];
	int key;
	int cnt,siz,fix;
};
// Fixed seed: priorities (and therefore tree shapes) are reproducible.
std::mt19937 mt(114514);
// NIL: empty sentinel; root: treap root; ncnt: last node handed out of tree[].
node *NIL,*root,*ncnt,tree[2000005];
// Recompute x's subtree multiplicity from its children; NIL is left untouched.
void pushup(node *x){
	if(x!=NIL)x->siz=x->cnt+x->ch[0]->siz+x->ch[1]->siz;
}
// Rotate the subtree held in the slot x: the child on side !d becomes the new
// subtree root and the old root moves down to side d. Sizes of both moved
// nodes are recomputed (lower one first).
void rotate(node *&x,int d){
	node *down=x;
	node *up=down->ch[!d];
	down->ch[!d]=up->ch[d];
	up->ch[d]=down;
	x=up;
	pushup(down);
	pushup(up);
}
// Take the next unused slot of tree[] and make it a leaf holding one copy of
// val with a fresh random priority.
node *newnode(int val){
	node *fresh=++ncnt;
	fresh->ch[0]=NIL;
	fresh->ch[1]=NIL;
	fresh->key=val;
	fresh->cnt=1;
	fresh->siz=1;
	fresh->fix=mt()%19260817;
	return fresh;
}
// Insert one copy of val into the treap rooted at rt (rt is updated in place).
// An existing key just gets its cnt bumped; otherwise a new leaf is created and
// rotated up while its priority (fix) is smaller than its parent's.
void insert(node *&rt,int val){
	if(rt==NIL){
		rt=newnode(val);
		return;
	}
	if(rt->key==val){
		rt->cnt++;
		rt->siz++;
		return;
	}
	int d=(val>=rt->key);
	insert(rt->ch[d],val);
	// rotate() already recomputes the sizes of both nodes it moves, so the
	// explicit pushup is only needed when no rotation happens.
	if(rt->ch[d]->fix<rt->fix)rotate(rt,!d);
	else pushup(rt);
}
// Remove one copy of val from the treap rooted at rt (rt is updated in place).
// A node holding the last copy is rotated down (smaller-priority child goes up)
// until it becomes a leaf, then unlinked. Deleting an absent value is a no-op.
void del(node *&rt,int val){
	// Reached an empty subtree: val is not present. Without this guard the
	// recursion kept descending through NIL forever (and with val==INT_MAX it
	// would decrement NIL's own counters, corrupting the sentinel).
	if(rt==NIL)return;
	if(rt->ch[0]==NIL&&rt->ch[1]==NIL){
		if(rt->key==val){
			rt->cnt--;
			rt->siz--;
			if(rt->cnt==0)rt=NIL;
			return;
		}
	}
	if(rt->key==val){
		if(rt->cnt>1){
			rt->cnt--;
			rt->siz--;
			return;
		}
		// Last copy and at least one real child: lift the child with the smaller
		// priority (NIL's INT_MAX priority never wins); this node lands on side d.
		int d=(rt->ch[0]->fix<rt->ch[1]->fix);
		rotate(rt,d);
		del(rt->ch[d],val);
	}else{
		int d=(val>=rt->key);
		del(rt->ch[d],val);
	}
	pushup(rt);
}
// 1-based rank of val: one more than the number of stored copies smaller than
// val. Walks down from rt iteratively, accumulating everything passed on the left.
int getrank(node *rt, int val){
	int below=0;
	while(rt!=NIL){
		if(val==rt->key)return below+rt->ch[0]->siz+1;
		if(val<rt->key){
			rt=rt->ch[0];
		}else{
			below+=rt->ch[0]->siz+rt->cnt;
			rt=rt->ch[1];
		}
	}
	return below+1;
}
// Node holding the k-th smallest copy (1-based, duplicates counted), or NIL
// when k is outside [1, total]. Iterative descent.
node *select(node *rt,int k){
	while(rt!=NIL){
		int lsz=rt->ch[0]->siz;
		if(k<=lsz){
			rt=rt->ch[0];
			continue;
		}
		if(k<=lsz+rt->cnt)return rt;
		k-=lsz+rt->cnt;
		rt=rt->ch[1];
	}
	return NIL;
}
// Node with the largest key smaller than val, or NIL if there is none
// (the element just before val's rank).
node *findprev(int val){
	return select(root,getrank(root,val)-1);
}
// Node with the smallest key greater than val, or NIL if there is none:
// getrank(val+1) is the position of the first copy that is >= val+1.
node *findnext(int val){
	return select(root,getrank(root,val+1));
}
int n,m,lst,ans;
// Reads n initial values and m online operations (each x is XORed with the
// previous answer), then prints the XOR of all query answers.
signed main(){
	ios::sync_with_stdio(false);
	cin.tie(0);
	cout.tie(0);
	// tree[0] is the shared empty sentinel: it points to itself, is empty, and
	// its INT_MAX priority keeps it below every real node.
	NIL=&tree[0];
	NIL->ch[0]=NIL;
	NIL->ch[1]=NIL;
	NIL->fix=INT_MAX;
	NIL->key=INT_MAX;
	NIL->cnt=0;
	NIL->siz=0;
	ncnt=&tree[1];
	root=NIL;
	cin>>n>>m;
	for(int i=0;i<n;++i){
		int a;
		cin>>a;
		insert(root,a);
	}
	while(m--){
		int op,x;
		cin>>op>>x;
		x^=lst;
		switch(op){
		case 1:
			insert(root,x);
			break;
		case 2:
			del(root,x);
			break;
		case 3:
			lst=getrank(root,x);
			ans^=lst;
			break;
		case 4:
			lst=select(root,x)->key;
			ans^=lst;
			break;
		case 5:
			lst=findprev(x)->key;
			ans^=lst;
			break;
		default:
			lst=findnext(x)->key;
			ans^=lst;
			break;
		}
	}
	cout<<ans<<endl;
}

splay:

#include<bits/stdc++.h>
using namespace std;
// Splay-tree node: children, parent, key and subtree size. Duplicate keys are
// stored as separate nodes, so every node contributes exactly 1 to siz.
struct node{
	node *ch[2];
	node *fa;
	int key,siz;
}tree[2000005];
// root: current tree root; NIL: shared empty sentinel (tree[0], set up in init());
// ncnt: last node handed out of tree[].
node *root,*NIL,*ncnt;
// Make tree[0] the self-referencing NIL sentinel (siz stays 0 from static
// zero-initialisation) and start with an empty tree.
void init(){
	ncnt=&tree[0];
	NIL=ncnt;
	NIL->fa=NIL;
	NIL->ch[1]=NIL;
	NIL->ch[0]=NIL;
	root=NIL;
}
// Allocate the next slot of tree[] as a detached single-node tree holding val.
node *newnode(int val){
	node *fresh=++ncnt;
	fresh->key=val;
	fresh->siz=1;
	fresh->fa=NIL;
	fresh->ch[0]=NIL;
	fresh->ch[1]=NIL;
	return fresh;
}
// Recompute x's subtree size from its children; NIL is never written.
void pushup(node *x){
	if(x!=NIL)x->siz=1+x->ch[0]->siz+x->ch[1]->siz;
}
// Rotate x above its parent y (one zig/zag step), preserving in-order order.
// Relinks the grandparent and moves the global root when y was the root.
void rotate(node *x){
	node *y=x->fa;
	// d==1 when x is y's left child; x then takes y as its right child.
	int d=(x==y->ch[0]);
	x->fa=y->fa;
	if(y->fa!=NIL)y->fa->ch[y->fa->ch[1]==y]=x;
	y->ch[!d]=x->ch[d];
	if(x->ch[d]!=NIL)x->ch[d]->fa=y;
	x->ch[d]=y;
	y->fa=x;
	if(root==y)root=x;
	// y is now x's child, so it must be recomputed first. The previous order
	// (x before y) built x->siz from y's stale (whole-subtree) size, leaving
	// x->siz too large until x was demoted again.
	pushup(y);
	pushup(x);
}
// Rotate x upward until its parent is rt (rt==NIL makes x the root).
// Same-direction chains rotate the parent first (zig-zig); mixed directions
// rotate x twice (zig-zag).
void splay(node *x,node *rt){
	for(node *y=x->fa;y!=rt;y=x->fa){
		node *z=y->fa;
		if(z!=rt){
			bool xIsLeft=(x==y->ch[0]);
			bool yIsLeft=(y==z->ch[0]);
			rotate(xIsLeft==yIsLeft?y:x);
		}
		rotate(x);
	}
}
// Insert a new node for val below the slot rt (whose owner is p), bumping siz on
// every node passed (equal keys go left), then splay the new node to the root.
// Written as an iterative descent over child slots instead of recursion.
void insert(node *&rt,node *p,int val){
	node **slot=&rt;
	node *parent=p;
	while(*slot!=NIL){
		node *cur=*slot;
		cur->siz++;
		parent=cur;
		slot=&cur->ch[val>cur->key];
	}
	node *fresh=newnode(val);
	*slot=fresh;
	fresh->fa=parent;
	splay(fresh,NIL);
}
// Look val up below rt. On a hit, follow the run of equal keys down the left
// spine, splay that node and return it; on a miss, splay and return the last
// node visited. An empty tree yields NIL with no splay.
node *find(node *rt,int val){
	node *last=NIL;
	while(rt!=NIL){
		if(val==rt->key){
			for(node *q=rt->ch[0];q!=NIL&&q->key==val;q=q->ch[0])rt=q;
			splay(rt,NIL);
			return rt;
		}
		last=rt;
		rt=rt->ch[val>rt->key];
	}
	if(last!=NIL)splay(last,NIL);
	return last;
}
// Node with the largest key strictly smaller than val (splayed to the root),
// or NIL when none exists. The deepest qualifying node on the search path wins.
node *findprev(node *rt,int val){
	node *best=NIL;
	while(rt!=NIL){
		if(val<=rt->key){
			rt=rt->ch[0];
		}else{
			best=rt;
			rt=rt->ch[1];
		}
	}
	if(best!=NIL)splay(best,NIL);
	return best;
}
// Node with the smallest key strictly greater than val (splayed to the root),
// or NIL when none exists. The deepest qualifying node on the search path wins.
node *findnext(node *rt,int val){
	node *best=NIL;
	while(rt!=NIL){
		if(val>=rt->key){
			rt=rt->ch[1];
		}else{
			best=rt;
			rt=rt->ch[0];
		}
	}
	if(best!=NIL)splay(best,NIL);
	return best;
}
// Count the keys <= val in the tree rooted at rt, then splay the last node on the
// search path to the root. main() calls it with x-1 so that, with the INT_MIN
// sentinel counted, the result equals the 1-based rank of x.
int getrank(node *rt,int val){
	int res=0;
	// Was left uninitialised and splayed even when rt==NIL (UB on an empty tree).
	node *last=NIL;
	while(rt!=NIL){
		last=rt;
		int d=(rt->key<=val);
		if(d)res+=rt->ch[0]->siz+1;
		rt=rt->ch[d];
	}
	if(last!=NIL)splay(last,NIL);
	return res;
}
// Find the k-th smallest node (1-based, sentinels included) starting from the
// global root and splay it until its parent is f. Assumes 1 <= k <= tree size.
node *select(int k,node *f){
	node *cur=root;
	for(;;){
		int lsz=cur->ch[0]->siz;
		if(k<=lsz){
			cur=cur->ch[0];
			continue;
		}
		if(k==lsz+1)break;
		k-=lsz+1;
		cur=cur->ch[1];
	}
	splay(cur,f);
	return cur;
}
// Delete one copy of val. Splays the node of rank k-1 to the root and the node of
// rank k+1 right below it, which isolates the k-th node as the latter's left child.
// The INT_MIN/INT_MAX sentinels inserted by main() make both neighbours exist
// for every val below INT_MAX. Deleting an absent value is a no-op.
void del(node *rt,int val){
	int k=getrank(rt,val);
	// k counts keys <= val including the INT_MIN sentinel, so k<2 means val is
	// absent; select(k-1) would otherwise search for rank 0 forever.
	if(k<2)return;
	node *p=select(k-1,NIL);
	node *q=select(k+1,p);
	// When val is absent the k-th node is its predecessor: keep it.
	if(q->ch[0]->key==val)q->ch[0]=NIL;
	// Splaying q to the root lets rotate() recompute the sizes the cut made stale.
	splay(q,NIL);
}
int n,m,lst,ans;
// Reads n initial values and m online operations (each x is XORed with the
// previous answer), then prints the XOR of all query answers.
int main(){
	ios::sync_with_stdio(false);
	cin.tie(0);
	cout.tie(0);
	init();
	// Sentinels: every real key gets a predecessor and a successor, which del()
	// and the rank/select offsets below rely on.
	insert(root,NIL,INT_MIN);
	insert(root,NIL,INT_MAX);
	cin>>n>>m;
	for(int i=0;i<n;++i){
		int a;
		cin>>a;
		insert(root,NIL,a);
	}
	while(m--){
		int op,x;
		cin>>op>>x;
		x^=lst;
		switch(op){
		case 1:
			insert(root,NIL,x);
			break;
		case 2:
			del(root,x);
			break;
		case 3: // keys <= x-1 plus the INT_MIN sentinel == rank of x
			lst=getrank(root,x-1);
			ans^=lst;
			break;
		case 4: // +1 skips the INT_MIN sentinel
			lst=select(x+1,NIL)->key;
			ans^=lst;
			break;
		case 5:
			lst=findprev(root,x)->key;
			ans^=lst;
			break;
		default:
			lst=findnext(root,x)->key;
			ans^=lst;
			break;
		}
	}
	cout<<ans<<endl;
}

替罪羊:

#include<bits/stdc++.h>
// Shorthands for the children of the node whose index is held in a local named
// `id`; they only work inside functions that use exactly that name.
#define lc t[id].ls
#define rc t[id].rs
using namespace std;
// Balance factor: a non-root node is rebuilt when either child holds more than
// alpha times the node's subtree size (see check()).
const double alpha=0.8;
// Scapegoat-tree node; index 0 is the empty child.
struct SG{
	// child indices, 0 = none
	int ls,rs;
	// number of nodes in the subtree, lazily deleted ones (weg==0) included
	int siz;
	// total multiplicity (sum of weg) in the subtree
	int siv;
	// stored value
	int val;
	// multiplicity of val at this node; 0 means lazily deleted
	int weg;
}t[2000005];
// ltr: in-order buffer filled by preorder() for rebuild(); st: its fill level;
// ncnt: last allocated index (the first new node gets index 2);
// rt: fixed root index. check() never rebuilds it, and because t[] is
// zero-initialised it starts as a node with val 0 and weg 0.
int ltr[2000005],st,ncnt=1,rt=2e6;
// True when node id (never the fixed root) is unbalanced, i.e. one child's
// node count exceeds alpha times id's own node count.
bool check(int id){
	if(id==rt)return 0;
	double limit=t[id].siz*alpha;
	return 1.0*t[lc].siz>limit||1.0*t[rc].siz>limit;
}
void resolve(int id){
	t[id].siz=t[lc].siz+t[rc].siz+1;
	t[id].siv=t[lc].siv+t[rc].siv+t[id].weg;
}
// Despite its name this is an in-order walk: append the indices of all
// still-alive nodes (weg != 0) of id's subtree to ltr[], in key order.
void preorder(int id){
	int left=t[id].ls;
	int right=t[id].rs;
	if(left!=0)preorder(left);
	if(t[id].weg!=0){
		++st;
		ltr[st]=id;
	}
	if(right!=0)preorder(right);
}
// Build a balanced subtree from ltr[l..r] (already in key order) by taking the
// middle element as root. Returns its index, or -1 for an empty range.
int reunion(int l,int r){
	if(l>r)return -1;
	int mid=(l+r)>>1;
	int cur=ltr[mid];
	if(l==r){
		t[cur].ls=0;
		t[cur].rs=0;
		t[cur].siz=1;
		t[cur].siv=t[cur].weg;
		return cur;
	}
	int lsub=reunion(l,mid-1);
	int rsub=reunion(mid+1,r);
	t[cur].ls=(lsub==-1)?0:lsub;
	t[cur].rs=(rsub==-1)?0:rsub;
	resolve(cur);
	return cur;
}
// Flatten id's subtree (dropping lazily deleted nodes) and rebuild it balanced.
// Returns the new subtree root index, or 0 (the empty child) when no live node
// remains.
int rebuild(int id){
	st=0;
	preorder(id);
	// reunion(1,0) returns -1, which callers would store as a child index and
	// later read as t[-1]. An all-deleted subtree simply becomes empty.
	if(st==0)return 0;
	return reunion(1,st);
}
// Insert one copy of val into the subtree at id, whose parent is fa (op==0: id is
// fa's left child, op==1: right child). On the way back up every node is
// re-summarised and, if unbalanced, its subtree is rebuilt and re-attached to fa.
void insert(int id,int fa,int val,bool op){
	if(!id){
		int cur=++ncnt;
		t[cur].val=val;
		t[cur].siz=1;
		t[cur].weg=1;
		t[cur].siv=1;
		if(op)t[fa].rs=cur;
		else t[fa].ls=cur;
		return;
	}
	if(val==t[id].val){
		// Existing value: only its multiplicity grows (siz counts nodes).
		++t[id].weg;
		++t[id].siv;
		return;
	}
	bool goRight=!(val<t[id].val);
	if(goRight)insert(rc,id,val,1);
	else insert(lc,id,val,0);
	resolve(id);
	if(!check(id))return;
	int fresh=rebuild(id);
	if(op)t[fa].rs=fresh;
	else t[fa].ls=fresh;
}
// Lazily remove one copy of val from the subtree at id (parent fa, side op):
// the node stays in the tree and only its multiplicity drops. Ancestors are
// re-summarised and rebuilt when unbalanced, as in insert().
void del(int id,int fa,int val,bool op){
	if(!id)return;
	if(val==t[id].val){
		if(t[id].weg!=0){
			--t[id].weg;
			--t[id].siv;
		}
		return;
	}
	bool goRight=!(val<t[id].val);
	if(goRight)del(rc,id,val,1);
	else del(lc,id,val,0);
	resolve(id);
	if(!check(id))return;
	int fresh=rebuild(id);
	if(op)t[fa].rs=fresh;
	else t[fa].ls=fresh;
}
// Value of the k-th smallest copy (1-based, duplicates counted) in id's subtree.
// Returns -1 when k is outside [1, total] — e.g. find_prev() with no
// predecessor asks for k==0. Previously that recursed on t[0] forever.
int find_by_order(int id,int k){
	while(id!=0){
		int lsum=t[lc].siv;
		if(k<=lsum){
			id=lc;
			continue;
		}
		if(k<=lsum+t[id].weg)return t[id].val;
		k-=lsum+t[id].weg;
		id=rc;
	}
	return -1;
}
// Number of stored copies strictly smaller than val in id's subtree
// (iterative descent, adding everything passed on the left).
int order_of_key(int id,int val){
	int less=0;
	while(id!=0){
		if(val==t[id].val)return less+t[lc].siv;
		if(val<t[id].val){
			id=lc;
		}else{
			less+=t[lc].siv+t[id].weg;
			id=rc;
		}
	}
	return less;
}
// Largest stored value smaller than val: the element at position
// (count of values < val).
int find_prev(int val){
	return find_by_order(rt,order_of_key(rt,val));
}
// Smallest stored value greater than val: the element right after all
// copies <= val.
int find_next(int val){
	return find_by_order(rt,order_of_key(rt,val+1)+1);
}
int n,m,lst,ans;
// Reads n initial values and m online operations (each x is XORed with the
// previous answer), then prints the XOR of all query answers.
int main(){
	ios::sync_with_stdio(false);
	cin.tie(0);
	cout.tie(0);
	cin>>n>>m;
	for(int i=0;i<n;++i){
		int a;
		cin>>a;
		insert(rt,0,a,0);
	}
	while(m--){
		int op,x;
		cin>>op>>x;
		x^=lst;
		switch(op){
		case 1:
			insert(rt,0,x,0);
			break;
		case 2:
			del(rt,0,x,0);
			break;
		case 3:
			lst=order_of_key(rt,x)+1;
			ans^=lst;
			break;
		case 4:
			lst=find_by_order(rt,x);
			ans^=lst;
			break;
		case 5:
			lst=find_prev(x);
			ans^=lst;
			break;
		default:
			lst=find_next(x);
			ans^=lst;
			break;
		}
	}
	cout<<ans<<endl;
}

提示：上述三份代码均能通过 P6136【模板】普通平衡树（数据加强版）。

2024/12/21 10:43
加载中...