求助莫名RE
查看原帖
求助莫名RE
399936
辰云楼主2021/12/11 19:50
#include<iostream>
#include<cstdio>
#include<cstdlib>
#define ll long long
using namespace std;
inline ll read(){
	ll x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
	return x*f;
}


struct Node{
	Node* ch[2];
	int v;
	int r;
	int s;
	int flag;
	Node(int v){this->v=v;ch[0]=ch[1]=NULL;r=rand();s=1;flag=1;}
	bool operator < (const Node& rhs)const{
		return r<rhs.r;
	}
	int cmp(int x) const {
		if(x==v) return -1;
		return x<v ? 0 : 1;
	}
	void maintain(){
		s=flag;
		if(ch[0]!=NULL)s+=ch[0]->s;
		if(ch[1]!=NULL)s+=ch[1]->s;
	}
};

Node* root;

inline void turn(Node* &o, int d){
	Node* k=o->ch[d^1]; o->ch[d^1] = k->ch[d]; k->ch[d]=o;
	o->maintain(); k->maintain(); o=k;
}

inline void insert(Node* &o,int x){
	if(o==NULL) o=new Node(x);
	else {
		int d=o->cmp(x);
		if(d==-1) o->flag++;
		else {
			insert(o->ch[d],x);
			if(o->ch[d]->r > o->r) turn(o,d^1);
		}
	}
	o->maintain();
}

inline void del(Node* &o,int x){
	if(o==NULL) return ;
	int d=o->cmp(x);
	if(d==-1){
		if(o->flag > 1) o->flag--;
		else if(o->ch[0]==NULL&&o->ch[1]==NULL)o=NULL;
		else if(o->ch[0]==NULL)
			o=o->ch[1];
		else if(o->ch[1]==NULL)
			o=o->ch[0];
		else{
			int d2= (o->ch[0]->r > o->ch[1]->r ? 1 : 0);
			turn(o,d2);del(o->ch[d2],x);
		}
	}else 
		del(o->ch[d],x);
	o->maintain();
}

inline int getrank(Node* &o,int k){
	if(k<0||k>o->s||o==NULL) return 0;
	int ss=0;
	if(o->ch[0]!=NULL) ss=o->ch[0]->s;
	if(k>=ss+1&&k<=o->flag+ss)return o->v;
	if(ss>=k) return getrank(o->ch[0],k);
	else return getrank(o->ch[0],k - ss - o->flag);
}

inline int findrank(Node* &o,int k){
	if(o==NULL)return 0;
	int ss=0;
	if(o->ch[0]!=NULL) ss=o->ch[0]->s;
	if(o->v==k) return ss+1;
	if(o->v>k) return findrank(o->ch[0],k);
	else return findrank(o->ch[0],k) + ss + o->flag;
}

inline int pre(Node* o,int k){
	if(o==NULL)return 0;
	int ans=0;
	while(o!=NULL){
		if(o->v<k) ans=o->v,o=o->ch[1];
		else o=o->ch[0];
	}
	return ans;
}

inline int ne(Node* o,int k){
	if(o==NULL)return 0;
	int ans=0;
	while(o!=NULL){
		if(o->v>k) ans=o->v,o=o->ch[0];
		else o=o->ch[1];
	}
	return ans;
}

int main(){
	//int n=read();
	int n,op,x;
	scanf("%d",&n); 
	while(n--){
		//int op=read(),x=read();
		scanf("%d%d",&op,&x);
		if(op==1)insert(root,x);
		else if(op==2)del(root,x);
		else if(op==3)printf("%d\n",findrank(root,x));
		else if(op==4)printf("%d\n",getrank(root,x));
		else if(op==5)printf("%d\n",pre(root,x));
		else if(op==6)printf("%d\n",ne(root,x));
	}
	return 0;
}
2021/12/11 19:50
加载中...