52分求助
查看原帖
52分求助
136321
天朝理科生楼主2021/8/15 15:32

第 5 个测试点 WA,第 6~10 个测试点 TLE。

Splay代码:

#include <bits/stdc++.h>
using namespace std;
// Capacity of the node pool (index 0 is reserved as the "no node" marker).
const int MAXN = 100010;

// One splay-tree node, stored by index inside the static pool `tr`.
struct node {
	int fa;      // index of the parent node, 0 when there is none
	int son[2];  // son[0] = left child, son[1] = right child, 0 when absent
	int val;     // key stored in this node
	int siz;     // number of stored elements (with multiplicity) in this subtree
	int rec;     // how many copies of `val` this node represents
};

// Static node pool; zero-initialised, so tr[0] acts as an all-zero null node.
node tr[MAXN];

int n;   // number of operations still to be processed
int x;   // opcode of the current operation
int y;   // operand of the current operation
int sz;  // number of pool slots handed out so far
int rt;  // index of the current root, 0 when the tree is empty

void clear(int x) {
	tr[x].fa = tr[x].son[0] = tr[x].son[1] = tr[x].val = tr[x].siz = tr[x].rec = 0;
}

// Recomputes tr[x].siz from the node's own multiplicity plus the sizes of
// whichever children exist. Does nothing for the null index 0.
void update(int x) {
	if (!x) {
		return;
	}

	const int lson = tr[x].son[0];
	const int rson = tr[x].son[1];
	int total = tr[x].rec;

	total += lson ? tr[lson].siz : 0;
	total += rson ? tr[rson].siz : 0;
	tr[x].siz = total;
}

// Links node `x` as the `z`-side child (0 = left, 1 = right) of node `y`.
// Either side of the link is skipped when its index is the null index 0, so
// the null node itself is never written to.
void connect(int x, int y, int z) {
	const bool childExists = (x != 0);
	const bool parentExists = (y != 0);

	if (childExists) {
		tr[x].fa = y;
	}

	if (parentExists) {
		tr[y].son[z] = x;
	}
}

// Returns 1 when `x` is recorded as the right child of its parent, else 0.
int get(int x) {
	const int parent = tr[x].fa;
	return tr[parent].son[1] == x ? 1 : 0;
}

// Rotates `x` one level up, above its parent, then refreshes the subtree
// sizes of the two nodes whose children changed.
void rotate(int x) {
	const int parent = tr[x].fa;
	const int grand = tr[parent].fa;
	const int side = get(x);            // which child of `parent` x is
	const int parentSide = get(parent); // which child of `grand` parent is
	const int inner = tr[x].son[side ^ 1];

	// x's inner subtree takes x's old place under the parent ...
	connect(inner, parent, side);
	// ... the parent drops below x on the opposite side ...
	connect(parent, x, side ^ 1);
	// ... and x takes the parent's old place under the grandparent.
	connect(x, grand, parentSide);

	// The parent is now below x, so its size must be refreshed first.
	update(parent);
	update(x);
}

// Rotates `x` upwards until it has no parent and records it as the root.
void splay(int x) {
	for (;;) {
		const int parent = tr[x].fa;

		if (!parent) {
			break;
		}

		if (tr[parent].fa) {
			// Same-side chain: rotate the parent first; otherwise rotate x twice.
			const bool sameSide = (get(x) == get(parent));
			rotate(sameSide ? parent : x);
		}

		rotate(x);
	}

	rt = x;
}

void insert(int x) {
	if (rt == 0) {
		rt = ++sz;
		tr[sz].val = x;
		tr[sz].son[0] = tr[sz].son[1] = 0;
		tr[sz].rec = tr[sz].siz = 1;
		return;
	}

	int now = rt, f = 0;

	while (1) {
		if (tr[now].val == x) {
			tr[now].rec++;
			update(now);
			update(f);
			splay(now);
			return;
		}

		f = now;
		now = tr[f].son[x > tr[f].val];

		if (!now) {
			sz++;
			tr[sz].val = x;
			tr[sz].siz = tr[sz].rec = 1;
			tr[sz].son[0] = tr[sz].son[1] = 0;
			tr[sz].fa = f;
			tr[f].son[x > tr[f].val] = sz;
			update(f);
			splay(sz);
			return;
		}
	}
}

int find(int x) {
	int now = rt, ans = 0;

	while (1) {
		if (x < tr[now].val) {
			now = tr[now].son[0];
			continue;
		}

		ans += tr[tr[now].son[0]].siz;

		if (x == tr[now].val) {
			splay(now);
			return ans + 1;
		}

		ans += tr[now].rec;
		now = tr[now].son[1];
	}
}

int kth(int x) {
	int now = rt;

	while (1) {
		if (tr[now].son[0] && tr[tr[now].son[0]].siz >= x) {
			now = tr[now].son[0];
			continue;
		}

		if (tr[now].son[0]) {
			x -= tr[tr[now].son[0]].siz;
		}

		if (x <= tr[now].rec) {
			splay(now);
			return tr[now].val;
		}

		x -= tr[now].rec;
		now = tr[now].son[1];
	}
}

// Returns the index of the right-most node in the root's left subtree, or 0
// when the root has no left child.
int pre() {
	int cur = tr[rt].son[0];

	for (int right = tr[cur].son[1]; right; right = tr[cur].son[1]) {
		cur = right;
	}

	return cur;
}

// Returns the index of the left-most node in the root's right subtree, or 0
// when the root has no right child.
int nxt() {
	int cur = tr[rt].son[1];

	for (int left = tr[cur].son[0]; left; left = tr[cur].son[0]) {
		cur = left;
	}

	return cur;
}

void del(int x) {
	int useless = find(x);

	if (!tr[rt].son[0] && !tr[rt].son[1]) {
		clear(rt);
		rt = 0;
		return;
	}

	if (!tr[rt].son[0]) {
		int tmp = rt;
		rt = tr[rt].son[1];
		tr[rt].fa = 0;
		clear(tmp);
		return;
	}

	if (!tr[rt].son[1]) {
		int tmp = rt;
		rt = tr[rt].son[0];
		tr[rt].fa = 0;
		clear(tmp);
		return;
	}

	int tmp = rt, left = pre();
	splay(left);
	connect(tr[tmp].son[1], rt, 1);
	clear(tmp);
	update(rt);
}

// Debug helper: prints the subtree rooted at `x` in pre-order. Each value is
// followed by a space, and the children of a non-leaf node are wrapped in
// square brackets.
void p(int x) {
	const int lson = tr[x].son[0];
	const int rson = tr[x].son[1];

	printf("%d ", tr[x].val);

	if (lson || rson) {
		printf("[");

		if (lson) {
			p(lson);
		}

		if (rson) {
			p(rson);
		}

		printf("]");
	}
}

// Reads the operation count, then one "<opcode> <operand>" pair per
// operation, and dispatches:
//   1 insert, 2 delete, 3 rank of value, 4 value at rank,
//   5 predecessor, 6 successor.
// Opcodes 3-6 print one line each.
int main() {
#ifndef ONLINE_JUDGE
	// Local debugging only: replay a saved test case from disk. A judge that
	// defines ONLINE_JUDGE keeps the real stdin/stdout.
	freopen("P3369_5.in", "r", stdin);
	freopen("P3369_5.ans", "w", stdout);
#endif

	if (scanf("%d", &n) != 1) {
		return 0;
	}

	while (n--) {
		// Stop on truncated input instead of replaying stale x / y values.
		if (scanf("%d%d", &x, &y) != 2) {
			break;
		}

		switch (x) {
			case 1: {
				insert(y);
				break;
			}

			case 2: {
				del(y);
				break;
			}

			case 3: {
				printf("%d\n", find(y));
				break;
			}

			case 4: {
				printf("%d\n", kth(y));
				break;
			}

			case 5: {
				// Put a temporary copy of y at the root, read the right-most
				// node of the root's left subtree, then remove the copy again.
				insert(y);
				printf("%d\n", tr[pre()].val);
				del(y);
				break;
			}

			case 6: {
				// Same trick, reading the left-most node of the right subtree.
				insert(y);
				printf("%d\n", tr[nxt()].val);
				del(y);
				break;
			}

			default: {
				break;
			}
		}

		// Debug aid: uncomment to dump the tree after every operation.
		//p(rt);
		//printf("\n");
	}

	return 0;
}
2021/8/15 15:32
加载中...