求助 splay
查看原帖
求助 splay
151601
Rusalka(楼主) 2020/11/10 21:23

刚学 splay,普通版能过,但这里连样例都过不去。

#include <cstdio>

// Capacity limits: MAXN bounds the initial values, MAXM the operations
// (each operation can allocate at most one new node).
const int MAXN = 100010;
const int MAXM = 1000010;

int n, m;       // number of initial values, number of operations
int a[MAXN];    // initial values, 1-indexed

// Node pool for the splay tree. Index 0 is the null link: a child or parent
// field equal to 0 means "none".
struct node {
	int siz;      // total multiplicity stored in this subtree
	int fa;       // parent index
	int ch[2];    // ch[0] = left child, ch[1] = right child
	int cnt;      // multiplicity of val at this node
	int val;      // key held by this node
} t[MAXN + MAXM];

int tot = 0;    // number of pool slots handed out so far
int rt;         // index of the current root (0 while the tree is empty)
inline void update(int x)
{
	t[x].siz = t[t[x].ch[0]].siz + t[t[x].ch[1]].siz + t[x].cnt;
}
// True when x is stored in its parent's right-child slot (ch[1]).
inline bool ident(int x)
{
	int parent = t[x].fa;
	return x == t[parent].ch[1];
}
// Make x the k-th child of fa (k = 0 left, k = 1 right), setting both the
// child pointer and the parent pointer.
inline void connect(int x, int fa, int k)
{
	t[x].fa = fa;
	t[fa].ch[k] = x;
}
// Rotate x one level up over its parent f:
//  - x takes f's slot under the grandparent ff;
//  - f becomes x's child on the side opposite to x's old side k;
//  - x's child on side k^1 (the subtree between x and f in key order)
//    is handed to f, in the slot x used to occupy.
// The three connect() calls must stay in this order: the second one calls
// ident(f), which reads f's old parent link that the third call overwrites.
// When f is the root (ff == 0) this also writes into node 0's ch/fa fields.
inline void rotate(int x)
{
	int f = t[x].fa, ff = t[f].fa, k = ident(x);   // k = 1 if x is f's right child
	connect(t[x].ch[k^1], f, k);   // x's inner subtree moves under f
	connect(x, ff, ident(f));      // x replaces f under ff
	connect(f, x, k^1);            // f hangs below x on the opposite side
	update(f);                     // f is now x's child: refresh it before x
	update(x);
}
// Splay x upward until its parent is y. Passing y == 0 makes x the new root
// and records it in rt (rt is assigned before any rotation happens).
// When x and its parent sit on the same side, the parent is rotated first
// (zig-zig); otherwise x is rotated twice (zig-zag).
void splay(int x, int y)
{
	if(y == 0)
		rt = x;
	for(int p = t[x].fa; p != y; p = t[x].fa)
	{
		int g = t[p].fa;
		if(g != y)
		{
			if(ident(x) == ident(p))
				rotate(p);   // zig-zig
			else
				rotate(x);   // zig-zag
		}
		rotate(x);
	}
}
inline void newnode(int& now, int fa, int v)
{
	t[now = ++tot].val = v;
	t[now].fa = fa;
	t[now].siz = t[now].cnt = 1;
	connect(now, fa, v>t[fa].val);
}
void ins(int val)
{
	int now = rt, lst = 0;
	while(now)
	{
		lst = now;
		if(t[now].val == val)
		{
			++t[now].cnt;
			splay(now, 0);
			return ;
		}
		now = t[now].ch[val>t[now].val];
	}
	newnode(now, lst, val);
	splay(now, 0);
}
// Remove one copy of the key held by node x. x is splayed to the root first.
// If it holds several copies only its count drops; otherwise x is unlinked and
// its in-order successor (or, failing that, its left child) becomes the root.
inline void delnode(int x)
{
	splay(x, 0);
	if(t[x].cnt > 1) {
		--t[x].cnt;
		update(x);   // x is the root, so only its own size needs refreshing
		return ;
	}
	if(t[x].ch[1])
	{
		// Bring the successor (leftmost node of the right subtree) directly
		// below x. It then has no left child, so x's left subtree can go there.
		int p = t[x].ch[1];
		while(t[p].ch[0]) p = t[p].ch[0];
		splay(p, x);
		connect(t[x].ch[0], p, 0);
		rt = p;
		t[p].fa = 0;
		update(rt);
	}
	else
	{
		// No right subtree: the left child (0 if the tree becomes empty) is the new root.
		rt = t[x].ch[0];
		t[rt].fa = 0;
	}
}
inline void del(int val)
{
	int now = rt;
	while(now)
	{
		if(t[now].val == val) {
			delnode(now);
			return ;
		}
		now = t[now].ch[val>t[now].val];
	}
}
inline int rnk(int val)
{
	int now = rt, res = 1;
	while(now)
	{
//		printf("%d %d %d %d %d\n",now,res,t[now].val,t[t[now].ch[0]].val,t[t[now].ch[1]].val);
		if(t[now].val == val)
		{
			res += t[t[now].ch[0]].siz;
			splay(now, 0);
			break;
		}
		else if(val < t[now].val)
			now = t[now].ch[0];
		else
		{
			res += t[t[now].ch[0]].siz+t[now].cnt;
			now = t[now].ch[1];
		}
	}
	return res;
}
inline int kth(int k)
{
	int now = rt;
	while(now)
	{
//		printf("%d %d %d %d\n",now,t[now].val,t[t[now].ch[0]].val,t[t[now].ch[1]].val);
		int ls = t[t[now].ch[0]].siz;
		if(k >= ls+1 && k <= ls+t[now].cnt)
			break;
		else if(ls >= k)
			now = t[now].ch[0];
		else
		{
			k -= (ls+t[now].cnt);
			now = t[now].ch[1];
		}
	}
	return t[now].val;
}
// Predecessor: the largest stored value strictly less than val.
// Elements below val occupy ranks 1 .. rnk(val)-1, so take the last of them.
// Assumes such an element exists (otherwise kth(0) is asked for).
inline int pre(int val)
{
	const int below = rnk(val) - 1;
	return kth(below);
}
// Successor: the smallest stored value strictly greater than val.
// rnk(val + 1) equals (#elements <= val) + 1, which is exactly that element's rank.
// NOTE(review): val + 1 overflows if val == INT_MAX.
inline int suf(int val)
{
	const int firstAbove = rnk(val + 1);
	return kth(firstAbove);
}

// Read n initial values and m XOR-encoded operations, then print a single
// line: the XOR of the answers to all query operations (3-6).
// Ops: 1 insert, 2 delete, 3 rank, 4 k-th, 5 predecessor, 6 successor.
int main()
{
	if(scanf("%d%d",&n,&m) != 2) return 0;
	for(int i=1;i<=n;i++)
	{
		if(scanf("%d",a+i) != 1) return 0;
		ins(a[i]);
	}
	// ans: answer of the most recent query (ops 3-6), used to decode the next x.
	// sum: XOR of all query answers. Updates (ops 1-2) contribute nothing.
	int ans = 0, sum = 0;
	while(m--)
	{
		int opt, x;
		if(scanf("%d%d",&opt,&x) != 2) break;
		x ^= ans;
		switch(opt)
		{
		case 1: ins(x); continue;   // updates produce no answer
		case 2: del(x); continue;
		case 3: ans = rnk(x); break;
		case 4: ans = kth(x); break;
		case 5: ans = pre(x); break;
		case 6: ans = suf(x); break;
		default: continue;          // unknown op: ignore
		}
		sum ^= ans;
	}
	printf("%d\n",sum);
	return 0;
}
2020/11/10 21:23
加载中...