兴致勃勃想学树套树,于是打了个板子,结果空间复杂度 O(n2) 甚至过不了编,外层线段树,内层 Splay,最后还是用 pbds 过的,有没有大佬帮忙改一下这份代码并解释一下原理。
正确性拍过。
马蜂有点丑勿喷。
// Problem: P3380 【模板】树套树
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P3380
// Memory Limit: 512 MB
// Time Limit: 2000 ms
// Date: 2025/01/12 19:07:54
// Author: Li_Feiy
#include <bits/stdc++.h>
#define arrout(a, n) rep(i, 1, n) printk(a[i])
#define arrin(a, n) rep(i, 1, n) a[i] = read()
#define rep(i, x, n) for(int i = x; i <= n; i++)
#define dep(i, x, n) for(int i = x; i >= n; i--)
#define erg(i, x) for(int i = head[x]; i; i = e[i].nex)
#define dbg(x) std::cout << #x << ":" << x << " "
#define mem(a, x) memset(a, x, sizeof a)
#define all(x) x.begin(), x.end()
#define arrall(a, n) a + 1, a + 1 + n
#define PII std::pair<int, int>
#define m_p std::make_pair
#define u_b upper_bound
#define l_b lower_bound
#define p_b push_back
#define CD const double
#define CI const int
#define int long long
#define il inline
#define ss second
#define ff first
#define itn int
int read() {
char ch = getchar();
int r = 0, w = 1;
while(ch < '0' || ch > '9') w = ch == '-' ? -1 : w, ch = getchar();
while(ch >= '0' && ch <= '9') r = (r << 3) + (r << 1) + (ch ^ 48), ch = getchar();
return r * w;
}
void print(int x) {
if(x < 0) putchar('-'), x = -x;
if(x >= 10) print(x / 10);
putchar(x % 10 + '0');
}template<typename ...Args>
void print(int t, Args... args) { print(t), print(args...); }
void printl(int x) { print(x), putchar('\n'); }
template<typename ...Args>
void printl(int t, Args... args) { printl(t), printl(args...); }
void printk(int x) { print(x), putchar(' '); }
template<typename ...Args>
void printk(int t, Args ... args) { printk(t), printk(args...); }
CI N = 1e3 + 5, INF = 2147483647;
int n, q, a[N];
struct Splay_Tree {
int rt, tot, fa[N], cnt[N], val[N], size[N], ch[N][2];
Splay_Tree() { insert(INF), insert(-INF); }
void maintain(int k) { size[k] = size[ch[k][0]] + size[ch[k][1]] + cnt[k]; }
int pos(int k) { return k == ch[fa[k]][1]; }
void del(int k) { fa[k] = cnt[k] = val[k] = size[k] = ch[k][0] = ch[k][1] = 0; }
void rotate(int k) {
int y = fa[k], z = fa[y], op = pos(k);
ch[y][op] = ch[k][!op];
if(ch[k][!op]) fa[ch[k][!op]] = y;
ch[k][!op] = y;
fa[y] = k, fa[k] = z;
if(z) ch[z][y == ch[z][1]] = k;
maintain(y), maintain(k);
}
void splay(int k) {
for(int f = fa[k]; f = fa[k], f; rotate(k))
if(fa[f]) rotate(pos(f) == pos(k) ? f : k);
rt = k;
}
void insert(int x) {
if(!rt) {
val[++tot] = x, cnt[tot] = 1;
rt = tot;
maintain(rt);
return ;
}
int k = rt, f = fa[k];
while(1) {
if(val[k] == x) {
cnt[k] ++;
maintain(k), maintain(f);
splay(k);
return ;
}
f = k, k = ch[k][x > val[k]];
if(!k) {
val[++tot] = x, cnt[tot] = 1, fa[tot] = f, ch[f][x > val[f]] = tot;
maintain(tot), maintain(f);
splay(tot);
return ;
}
}
}
int find(int x) {
int k = rt;
while(1) {
if(x < val[k]) k = ch[k][0];
else {
if(!k) puts("ERROR"), exit(0);
if(x == val[k]) { splay(k); return k; }
k = ch[k][1];
}
}
}
int rank(int x) {
int k = rt, rk = 0;
while(1) {
if(x < val[k]) k = ch[k][0];
else {
rk += size[ch[k][0]];
if(!k) return rk;
if(x == val[k]) { splay(k); return rk; }
rk += cnt[k], k = ch[k][1];
}
}
}
int kth(int x) {
x ++;
int k = rt;
while(1) {
if(ch[k][0] && x <= size[ch[k][0]]) k = ch[k][0];
else {
if(x <= size[ch[k][0]] + cnt[k]) { splay(k); return val[k]; }
x -= size[ch[k][0]] + cnt[k], k = ch[k][1];
}
}
}
int pre() {
int k = ch[rt][0];
if(!k) return k;
while(ch[k][1]) k = ch[k][1];
splay(k);
return k;
}
int find_pre(int x) {
int ans = -INF;
insert(x), ans = pre(), remove(x);
return val[ans];
}
int suc() {
int k = ch[rt][1];
if(!k) return k;
while(ch[k][0]) k = ch[k][0];
splay(k);
return k;
}
int find_suc(int x) {
int ans = INF;
insert(x), ans = suc(), remove(x);
return val[ans];
}
int find_max() {
int k = rt;
while(ch[k][1]) k = ch[k][1];
return val[fa[k]];
}
void remove(int x) {
rank(x);
if(cnt[rt] > 1) { cnt[rt] --, maintain(rt); return ; }
if(!ch[rt][0] && !ch[rt][1]) { del(rt), rt = 0; return ; }
if(!ch[rt][0]) {
int k = rt;
rt = ch[rt][1], fa[rt] = 0;
del(k);
return ;
}
if(!ch[rt][1]) {
int k = rt;
rt = ch[rt][0], fa[rt] = 0;
del(k);
return ;
}
int k = rt, t = pre();
ch[t][1] = ch[k][1];
fa[ch[k][1]] = t;
del(k);
maintain(rt);
}
} ;
struct Segment_Tree {
#define ls k << 1
#define rs k << 1 | 1
Splay_Tree s[N << 1];
void build(int k = 1, int l = 1, int r = n) {
rep(i, l, r) s[k].insert(a[i]);
if(l == r) return ;
int mid = l + r >> 1;
build(ls, l, mid), build(rs, mid + 1, r);
}
void update(int x, int v, int k = 1, int l = 1, int r = n) {
s[k].remove(a[x]), s[k].insert(v);
if(l == r) return ;
int mid = l + r >> 1;
if(x <= mid) update(x, v, ls, l, mid);
else update(x, v, rs, mid + 1, r);
}
int query_max(int x, int y, int k = 1, int l = 1, int r = n) {
if(y < l || x > r) return 0;
if(x <= l && r <= y) return s[k].find_max();
int mid = l + r >> 1, ans = -INF;
if(x <= mid) ans = std::max(ans, query_max(x, y, ls, l, mid));
if(y > mid) ans = std::max(ans, query_max(x, y, rs, mid + 1, r));
return ans;
}
int query_pre(int x, int y, int v, int k = 1, int l = 1, int r = n) {
if(y < l || x > r) return 0;
if(x <= l && r <= y) return s[k].find_pre(v);
int mid = l + r >> 1, ans = -INF;
if(x <= mid) ans = std::max(ans, query_pre(x, y, v, ls, l, mid));
if(y > mid) ans = std::max(ans, query_pre(x, y, v, rs, mid + 1, r));
return ans;
}
int query_suc(int x, int y, int v, int k = 1, int l = 1, int r = n) {
if(y < l || x > r) return 0;
if(x <= l && r <= y) return s[k].find_suc(v);
int mid = l + r >> 1, ans = INF;
if(x <= mid) ans = std::min(ans, query_suc(x, y, v, ls, l, mid));
if(y > mid) ans = std::min(ans, query_suc(x, y, v, rs, mid + 1, r));
return ans;
}
int query_rank(int x, int y, int v, int k = 1, int l = 1, int r = n) {
if(y < l || x > r) return 0;
// printk(l, r, v, s[k].rank(v) - 1), puts("");
if(x <= l && r <= y) return s[k].rank(v) - 1;
int mid = l + r >> 1, ans = 0;
if(x <= mid) ans += query_rank(x, y, v, ls, l, mid);
if(y > mid) ans += query_rank(x, y, v, rs, mid + 1, r);
return ans;
}
#undef ls
#undef rs
} st;
int binary_search(int l, int r, int k) {
int L = 0, R = 1e9, ans = 0;
while(L <= R) {
int mid = L + R >> 1;
// printl(mid);
if(st.query_rank(l, r, mid) + 1 <= k) L = mid + 1, ans = mid;
else R = mid - 1;
}
return ans;
}
signed main() {
n = read(), q = read();
arrin(a, n);
st.build();
// Splay_Tree tr;
// tr.insert(4), tr.insert(2), tr.insert(2), tr.insert(10);
// printl(st.query_rank(1, 4, 2) + 1, tr.rank(2));
while(q --) {
int op = read(), l, r, k;
switch(op) {
case 1:
l = read(), r = read(), k = read();
printl(st.query_rank(l, r, k) + 1);
break;
case 2:
l = read(), r = read(), k = read();
printl(binary_search(l, r, k));
break;
case 3:
l = read(), k = read();
st.update(l, k), a[l] = k;
break;
case 4:
l = read(), r = read(), k = read();
printl(st.query_pre(l, r, k));
break;
case 5:
l = read(), r = read(), k = read();
printl(st.query_suc(l, r, k));
break;
default:
break;
}
}
return 0;
}