WA on #4-#12 and #14 代码:
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+5;
struct st
{
int son[2],fa,v,siz,sum;
void Init(int x,int y)
{
fa=x,v=y;siz=sum=1;
son[0]=son[1]=0;
}
}s[N];
int root,cnt,n,op,x;
void pushup(int now)
{
s[now].siz=s[s[now].son[0]].siz+s[s[now].son[1]].siz+s[now].sum;
}
void rotate(int x)
{
int y=s[x].fa,z=s[y].fa;
bool r=(x==s[y].son[1]);
s[y].son[r]=s[x].son[r^1];
if(s[x].son[r^1]) s[s[x].son[r^1]].fa=y;
s[x].son[r^1]=y;
s[y].fa=x;
s[x].fa=z;
if(z) s[z].son[(y==s[z].son[1])]=x;
pushup(y);pushup(x);
}
void splay(int x)
{
while(s[x].fa)
{
int y=s[x].fa,z=s[y].fa;
if(z)
{
if((s[y].son[1]==x)^(s[z].son[1]==y)) rotate(x);
else rotate(y);
}
rotate(x);
}
root=x;
}
int last(int x)
{
int now=root,ans=-1e9;
while(now)
{
if(s[now].v<x) ans=s[now].v,now=s[now].son[1];
else now=s[now].son[0];
}
return ans;
}
int next(int x)
{
int now=root,ans=1e9;
while(now)
{
if(s[now].v>x) ans=s[now].v,now=s[now].son[0];
else now=s[now].son[1];
}
return ans;
}
void ins(int x)
{
if(!root)
{
s[++cnt].Init(0,x);
root=cnt;
return;
}
int now=root,fa=0;
while(1)
{
if(s[now].v==x)
{
s[now].sum++;
pushup(now);
splay(now);
break;
}
fa=now,now=s[now].son[s[now].v<x];
if(!now)
{
s[++cnt].Init(fa,x);
s[fa].son[s[fa].v<x]=cnt;
pushup(fa);
splay(cnt);
break;
}
}
}
int ranks(int x)
{
int now=root,ans=0;
while(now)
{
if(s[now].v==x){splay(now);return ans+s[s[now].son[0]].siz+1;
}
if(x<s[now].v) now=s[now].son[0];
else ans+=s[s[now].son[0]].siz+s[now].sum,now=s[now].son[1];
}
return ans+1;
}
int kth(int x)
{
int now=root;
while(1)
{
if(x>s[s[now].son[0]].siz&&x<=s[s[now].son[0]].siz+s[now].sum) return s[now].v;
if(s[now].son[0]&&x<=s[s[now].son[0]].siz) now=s[now].son[0];
else x-=s[s[now].son[0]].siz+s[now].sum,now=s[now].son[1];
}
}
void del(int x)
{
ranks(x);
if(s[root].v!=x) return;
if(s[root].sum>1) s[root].sum--,pushup(root);
else
{
if(!s[root].son[0]&&!s[root].son[1]) root=0;
else if(!s[root].son[0]) root=s[root].son[1],s[root].fa=0;
else if(!s[root].son[1]) root=s[root].son[0],s[root].fa=0;
else
{
int l=s[root].son[0],now=root;
while(s[l].son[1]) l=s[l].son[1];
splay(l);
s[l].son[1]=s[now].son[1];
s[s[now].son[1]].fa=l;
root=l;
s[root].fa=0;
pushup(root);
}
}
}
int main()
{
ios::sync_with_stdio(0);
cin.tie(0);
cin>>n;
while(n--)
{
cin>>op>>x;
if(op==1) ins(x);
if(op==2) del(x);
if(op==3) cout<<ranks(x)<<'\n';
if(op==4) cout<<kth(x)<<'\n';
if(op==5) cout<<last(x)<<'\n';
if(op==6) cout<<next(x)<<'\n';
}
return 0;
}
另附用AI修改我代码后AC代码,但我拿他调试后找不到问题:
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+10; // 扩大数组大小
struct Node {
int son[2],fa,v,siz,sum;
void Init(int x,int y) {
fa=x,v=y;siz=sum=1;
son[0]=son[1]=0;
}
}s[N];
int root,cnt,n,op,x;
void pushup(int now) {
s[now].siz=s[s[now].son[0]].siz+s[s[now].son[1]].siz+s[now].sum;
}
void rotate(int x) {
int y=s[x].fa,z=s[y].fa;
bool r=(x==s[y].son[1]);
s[y].son[r]=s[x].son[r^1];
if(s[x].son[r^1]) s[s[x].son[r^1]].fa=y;
s[x].son[r^1]=y;
s[y].fa=x;
s[x].fa=z;
if(z) s[z].son[(y==s[z].son[1])]=x;
pushup(y);pushup(x);
}
void splay(int x,int goal=0) {
while(s[x].fa!=goal) {
int y=s[x].fa,z=s[y].fa;
if(z!=goal)
((s[z].son[1]==y)^(s[y].son[1]==x))?rotate(x):rotate(y);
rotate(x);
}
if(!goal) root=x;
}
int find(int x) {
int now=root,res=0;
while(now) {
if(s[now].v==x) {res=now;break;}
now=s[now].son[s[now].v<x];
}
if(res) splay(res);
return res;
}
void ins(int x) {
if(!root) {
s[++cnt].Init(0,x);
root=cnt;
return;
}
int now=root,fa=0;
while(1) {
if(s[now].v==x) {
s[now].sum++;
pushup(now);pushup(fa);
splay(now);
break;
}
fa=now;
now=s[now].son[s[now].v<x];
if(!now) {
s[++cnt].Init(fa,x);
s[fa].son[s[fa].v<x]=cnt;
pushup(fa);
splay(cnt);
break;
}
}
}
int get_rank(int x) {
int now=root,res=0;
while(now) {
if(x<s[now].v) now=s[now].son[0];
else {
res+=s[s[now].son[0]].siz;
if(x==s[now].v) {splay(now);return res+1;}
res+=s[now].sum;
now=s[now].son[1];
}
}
return res+1;
}
int kth(int x) {
int now=root;
while(1) {
if(s[now].son[0]&&x<=s[s[now].son[0]].siz) now=s[now].son[0];
else {
int tmp=s[s[now].son[0]].siz+s[now].sum;
if(x<=tmp) return s[now].v;
x-=tmp;
now=s[now].son[1];
}
}
}
int get_pre(int x) {
int now=root,res=-1e9;
while(now) {
if(s[now].v<x) res=s[now].v,now=s[now].son[1];
else now=s[now].son[0];
}
return res;
}
int get_suc(int x) {
int now=root,res=1e9;
while(now) {
if(s[now].v>x) res=s[now].v,now=s[now].son[0];
else now=s[now].son[1];
}
return res;
}
void del(int x) {
int now=find(x);
if(!now) return;
if(s[now].sum>1) {
s[now].sum--;
pushup(now);
return;
}
if(!s[now].son[0]&&!s[now].son[1]) root=0;
else if(!s[now].son[0]) root=s[now].son[1],s[root].fa=0;
else if(!s[now].son[1]) root=s[now].son[0],s[root].fa=0;
else {
int tmp=s[now].son[0];
while(s[tmp].son[1]) tmp=s[tmp].son[1];
splay(tmp,now);
s[tmp].son[1]=s[now].son[1];
s[s[now].son[1]].fa=tmp;
root=tmp;
s[root].fa=0;
pushup(root);
}
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cin>>n;
while(n--) {
cin>>op>>x;
if(op==1) ins(x);
else if(op==2) del(x);
else if(op==3) cout<<get_rank(x)<<'\n';
else if(op==4) cout<<kth(x)<<'\n';
else if(op==5) cout<<get_pre(x)<<'\n';
else if(op==6) cout<<get_suc(x)<<'\n';
}
return 0;
}