关于树套树
  • 板块学术版
  • 楼主Autream
  • 当前回复3
  • 已保存回复4
  • 发布时间2025/1/13 21:11
  • 上次更新2025/1/14 11:03:55
查看原帖
关于树套树
941431
Autream楼主2025/1/13 21:11

兴致勃勃地想学树套树，于是打了个板子（外层线段树，内层 Splay），结果空间复杂度是 $O(n^2)$，连编译都过不了，最后还是用 pbds 过的。有没有大佬帮忙改一下这份代码，并解释一下原理？

正确性拍过。

码风有点丑，勿喷。

// Problem: P3380 【模板】树套树
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P3380
// Memory Limit: 512 MB
// Time Limit: 2000 ms
// Date: 2025/01/12 19:07:54
// Author: Li_Feiy

#include <bits/stdc++.h>
// ---- Competitive-programming shorthand macros ----
// Print / read a[1..n] (1-indexed arrays) using the helpers defined below.
#define arrout(a, n) rep(i, 1, n) printk(a[i])
#define arrin(a, n) rep(i, 1, n) a[i] = read()
// Inclusive loops: i = x, x+1, ..., n (rep) and i = x, x-1, ..., n (dep).
#define rep(i, x, n) for(int i = x; i <= n; i++)
#define dep(i, x, n) for(int i = x; i >= n; i--)
// Adjacency-list edge loop; expects globals head[] and e[].nex, which this file
// does not define (the macro is unused here).
#define erg(i, x) for(int i = head[x]; i; i = e[i].nex)
// Debug print: the expression's source text, ':' and its value, to std::cout.
#define dbg(x) std::cout << #x << ":" << x << " "
#define mem(a, x) memset(a, x, sizeof a)
#define all(x) x.begin(), x.end()
// Iterator pair [a + 1, a + 1 + n) for a 1-indexed array.
#define arrall(a, n) a + 1, a + 1 + n
#define PII std::pair<int, int>
#define m_p std::make_pair
#define u_b upper_bound
#define l_b lower_bound
#define p_b push_back
#define CD const double
#define CI const int
// From here on every `int` token is `long long`. Macros above that mention
// `int` (CI, PII, rep, dep, ...) are expanded at their use sites, so they are
// 64-bit as well. This is also why main is declared as `signed main`.
#define int long long
#define il inline
#define ss second
#define ff first
// Alias that makes the common typo `itn` compile as `int` (i.e. long long).
#define itn int
// Fast signed-integer reader from stdin.
// Skips every non-digit character; if a '-' appears among the skipped
// characters the result is negative. Then it consumes a run of digits and the
// one character that follows them.
// Returns 0 if EOF is reached before any digit. `ch` is kept as the int
// returned by getchar so EOF can be detected; a char-typed `ch` compared
// against '0'..'9' never stops at EOF and loops forever.
int read() {
	int ch = getchar();
	int r = 0, w = 1;
	while(ch != EOF && (ch < '0' || ch > '9')) {
		if(ch == '-') w = -1;
		ch = getchar();
	}
	while(ch >= '0' && ch <= '9') {
		r = r * 10 + (ch - '0');
		ch = getchar();
	}
	return r * w;
}

// Decimal output helpers writing to stdout through putchar:
//   print  - values back to back, no separator
//   printl - each value followed by '\n'
//   printk - each value followed by ' '
// The variadic overloads apply the single-value version to each argument in order.
void print(int x) {
	if(x < 0) {
		putchar('-');
		x = -x;
	}
	// Collect digits least-significant first, then emit them in reverse.
	char digits[24];
	int len = 0;
	do {
		digits[len++] = char('0' + x % 10);
	} while(x /= 10);
	while(len > 0) putchar(digits[--len]);
}
template<typename ...Args>
void print(int t, Args... args) {
	print(t);
	print(args...);
}

void printl(int x) {
	print(x);
	putchar('\n');
}
template<typename ...Args>
void printl(int t, Args... args) {
	printl(t);
	printl(args...);
}

void printk(int x) {
	print(x);
	putchar(' ');
}
template<typename ...Args>
void printk(int t, Args ... args) {
	printk(t);
	printk(args...);
}

// N: capacity bound for the input; a is 1-indexed, so n must be at most N - 1.
// NOTE(review): N = 1e3 + 5 looks deliberately shrunk so that per-tree splay
// arrays of length N still fit in memory. If the problem allows larger n,
// raise N only once the inner splay trees share one node pool instead of
// owning N-length arrays each. Confirm against the problem's limits.
// INF: value of the +/- sentinels stored in every splay tree. Predecessor and
// successor queries therefore yield -INF / INF when nothing qualifies.
CI N = 1e3 + 5, INF = 2147483647;
// n: number of elements, q: number of operations, a[1..n]: current values.
int n, q, a[N];
struct Splay_Tree {
	int rt, tot, fa[N], cnt[N], val[N], size[N], ch[N][2];
	Splay_Tree() { insert(INF), insert(-INF); }
	void maintain(int k) { size[k] = size[ch[k][0]] + size[ch[k][1]] + cnt[k]; }
	int pos(int k) { return k == ch[fa[k]][1]; }
	void del(int k) { fa[k] = cnt[k] = val[k] = size[k] = ch[k][0] = ch[k][1] = 0; }
	void rotate(int k) {
		int y = fa[k], z = fa[y], op = pos(k);
		ch[y][op] = ch[k][!op];
		if(ch[k][!op]) fa[ch[k][!op]] = y;
		ch[k][!op] = y;
		fa[y] = k, fa[k] = z;
		if(z) ch[z][y == ch[z][1]] = k;
		maintain(y), maintain(k);
	}
	void splay(int k) {
		for(int f = fa[k]; f = fa[k], f; rotate(k))
			if(fa[f]) rotate(pos(f) == pos(k) ? f : k);
		rt = k;
	}
	void insert(int x) {
		if(!rt) {
			val[++tot] = x, cnt[tot] = 1;
			rt = tot;
			maintain(rt);
			return ;
		}
		int k = rt, f = fa[k];
		while(1) {
			if(val[k] == x) {
				cnt[k] ++;
				maintain(k), maintain(f);
				splay(k);
				return ;
			}
			f = k, k = ch[k][x > val[k]];
			if(!k) {
				val[++tot] = x, cnt[tot] = 1, fa[tot] = f, ch[f][x > val[f]] = tot;
				maintain(tot), maintain(f);
				splay(tot);
				return ;
			}
		}
	}
	int find(int x) {
		int k = rt;
		while(1) {
			if(x < val[k]) k = ch[k][0];
			else {
				if(!k) puts("ERROR"), exit(0);
				if(x == val[k]) { splay(k); return k; }
				k = ch[k][1];
			}
		}
	}
	int rank(int x) {
		int k = rt, rk = 0;
		while(1) {
			if(x < val[k]) k = ch[k][0];
			else {
				rk += size[ch[k][0]];
				if(!k) return rk;
				if(x == val[k]) { splay(k); return rk; }
				rk += cnt[k], k = ch[k][1];
			}
		}
	}
	int kth(int x) {
		x ++;
		int k = rt;
		while(1) {
			if(ch[k][0] && x <= size[ch[k][0]]) k = ch[k][0];
			else {
				if(x <= size[ch[k][0]] + cnt[k]) { splay(k); return val[k]; }
				x -= size[ch[k][0]] + cnt[k], k = ch[k][1];
			}
		}
	}
	int pre() {
		int k = ch[rt][0];
		if(!k) return k;
		while(ch[k][1]) k = ch[k][1];
		splay(k);
		return k;
	}
	int find_pre(int x) {
		int ans = -INF;
		insert(x), ans = pre(), remove(x);
		return val[ans];
	}
	int suc() {
		int k = ch[rt][1];
		if(!k) return k;
		while(ch[k][0]) k = ch[k][0];
		splay(k);
		return k;
	}
	int find_suc(int x) {
		int ans = INF;
		insert(x), ans = suc(), remove(x);
		return val[ans];
	}
	int find_max() {
		int k = rt;
		while(ch[k][1]) k = ch[k][1];
		return val[fa[k]];
	}
	void remove(int x) {
		rank(x);
		if(cnt[rt] > 1) { cnt[rt] --, maintain(rt); return ; }
		if(!ch[rt][0] && !ch[rt][1]) { del(rt), rt = 0; return ; }
		if(!ch[rt][0]) {
			int k = rt;
			rt = ch[rt][1], fa[rt] = 0;
			del(k);
			return ;
		}
		if(!ch[rt][1]) {
			int k = rt;
			rt = ch[rt][0], fa[rt] = 0;
			del(k);
			return ;
		}
		int k = rt, t = pre();
		ch[t][1] = ch[k][1];
		fa[ch[k][1]] = t;
		del(k);
		maintain(rt);
	}
} ;
// Outer segment tree over positions 1..n. Node k covers a[l..r] and keeps all
// of those values in the splay tree s[k]. A point update rewrites one value in
// every tree on the root-to-leaf path. A range query combines the answers of
// the nodes that exactly cover [x, y].
struct Segment_Tree {
#define ls k << 1
#define rs k << 1 | 1
	// Heap indexing (children 2k and 2k+1) can reach almost 4n. With n = 1000
	// the leaf for position 999 is node 2045, past a 2N-sized array, so 4N is used.
	Splay_Tree s[N << 2];
	// Insert a[l..r] into s[k], then recurse into both halves.
	void build(int k = 1, int l = 1, int r = n) {
		rep(i, l, r) s[k].insert(a[i]);
		if(l == r) return ;
		int mid = l + r >> 1;
		build(ls, l, mid), build(rs, mid + 1, r);
	}
	// Replace a[x] by v in every tree whose range contains x. a[x] must still
	// hold the old value here; the caller assigns a[x] = v afterwards.
	void update(int x, int v, int k = 1, int l = 1, int r = n) {
		s[k].remove(a[x]), s[k].insert(v);
		if(l == r) return ;
		int mid = l + r >> 1;
		if(x <= mid) update(x, v, ls, l, mid);
		else update(x, v, rs, mid + 1, r);
	}
	// Max over the covering nodes of s[k].find_max() (intended: the maximum of a[x..y]).
	// -INF, the identity of max, when [x, y] misses [l, r].
	int query_max(int x, int y, int k = 1, int l = 1, int r = n) {
		if(y < l || x > r) return -INF;
		if(x <= l && r <= y) return s[k].find_max();
		int mid = l + r >> 1, ans = -INF;
		if(x <= mid) ans = std::max(ans, query_max(x, y, ls, l, mid));
		if(y > mid) ans = std::max(ans, query_max(x, y, rs, mid + 1, r));
		return ans;
	}
	// Largest value < v in a[x..y]; -INF if there is none or the ranges miss.
	int query_pre(int x, int y, int v, int k = 1, int l = 1, int r = n) {
		if(y < l || x > r) return -INF;
		if(x <= l && r <= y) return s[k].find_pre(v);
		int mid = l + r >> 1, ans = -INF;
		if(x <= mid) ans = std::max(ans, query_pre(x, y, v, ls, l, mid));
		if(y > mid) ans = std::max(ans, query_pre(x, y, v, rs, mid + 1, r));
		return ans;
	}
	// Smallest value > v in a[x..y]; INF if there is none or the ranges miss.
	int query_suc(int x, int y, int v, int k = 1, int l = 1, int r = n) {
		if(y < l || x > r) return INF;
		if(x <= l && r <= y) return s[k].find_suc(v);
		int mid = l + r >> 1, ans = INF;
		if(x <= mid) ans = std::min(ans, query_suc(x, y, v, ls, l, mid));
		if(y > mid) ans = std::min(ans, query_suc(x, y, v, rs, mid + 1, r));
		return ans;
	}
	// Number of values < v in a[x..y]. Each tree's rank() counts its -INF
	// sentinel, hence the "- 1".
	int query_rank(int x, int y, int v, int k = 1, int l = 1, int r = n) {
		if(y < l || x > r) return 0;
		if(x <= l && r <= y) return s[k].rank(v) - 1;
		int mid = l + r >> 1, ans = 0;
		if(x <= mid) ans += query_rank(x, y, v, ls, l, mid);
		if(y > mid) ans += query_rank(x, y, v, rs, mid + 1, r);
		return ans;
	}
#undef ls
#undef rs
} st;
// k-th smallest value of a[l..r], found by binary search on the answer. It is
// the largest v in [0, 1e9] such that fewer than k elements of a[l..r] are
// smaller than v. Returns 0 if no probed value satisfies that condition.
int binary_search(int l, int r, int k) {
	int lo = 0, hi = 1e9, best = 0;
	while(lo <= hi) {
		int mid = (lo + hi) >> 1;
		// "rank + 1 <= k" is the same as "count of smaller values < k".
		if(st.query_rank(l, r, mid) < k) {
			best = mid;
			lo = mid + 1;
		} else {
			hi = mid - 1;
		}
	}
	return best;
}
// Reads n, q and a[1..n], builds the tree of splays, then processes q operations:
//   1 l r k : 1 + number of values < k in a[l..r]
//   2 l r k : k-th smallest value in a[l..r]
//   3 p k   : set a[p] = k
//   4 l r k : largest value < k in a[l..r]
//   5 l r k : smallest value > k in a[l..r]
// Any other opcode is skipped without reading further arguments.
signed main() {
	n = read(), q = read();
	for(int i = 1; i <= n; ++i) a[i] = read();
	st.build();
	while(q --) {
		const int op = read();
		if(op == 3) {
			const int p = read();
			const int v = read();
			st.update(p, v);
			a[p] = v;
			continue;
		}
		if(op != 1 && op != 2 && op != 4 && op != 5) continue;
		const int l = read();
		const int r = read();
		const int k = read();
		if(op == 1) printl(st.query_rank(l, r, k) + 1);
		else if(op == 2) printl(binary_search(l, r, k));
		else if(op == 4) printl(st.query_pre(l, r, k));
		else printl(st.query_suc(l, r, k));
	}
	return 0;
}
2025/1/13 21:11
加载中...