// Code:
#include <iostream>
#include <vector>
#define endl '\n'
using namespace std;
/// Multiset of ints kept in a splay tree (the per-node container of the
/// segment tree below). Equal keys share one node whose `cnt` is the key's
/// multiplicity; `siz` is the total multiplicity of a node's subtree. Nodes
/// are indices into parallel vectors, index 0 is the null sentinel, and freed
/// indices are recycled through the `trash` stack so the pool cannot overflow.
class SplayTree {
public:
    /// Reset to an empty tree whose node pool holds N distinct keys plus spare
    /// slots (getprev/getnext temporarily insert one extra key).
    void init(int N) {
        root = tot = ttop = 0;
        ch = std::vector<std::vector<int>>(N + 3, std::vector<int>(2, 0));
        trash = val = cnt = siz = fa = std::vector<int>(N + 3, 0);
    }

    /// Insert one copy of x and splay its node to the root.
    void push(int x) {
        if (!root) {
            const int node = getnode();
            val[node] = x, cnt[node] = 1;
            pushup(node);
            root = node;
            return;
        }
        int cur = root, f = 0;
        while (true) {
            if (val[cur] == x) {
                ++cnt[cur];
                pushup(cur);
                if (f) pushup(f);  // f == 0 when cur is the root
                splay(cur);
                return;
            }
            f = cur;
            cur = ch[cur][val[cur] < x];
            if (!cur) {
                const int node = getnode();
                val[node] = x, cnt[node] = 1, fa[node] = f;
                ch[f][val[f] < x] = node;
                pushup(node);
                pushup(f);
                splay(node);
                return;
            }
        }
    }

    /// Remove one copy of x; does nothing when x is absent.
    void del(int x) {
        if (!contains(x)) return;  // when present, x's node is now the root
        if (cnt[root] > 1) {
            --cnt[root];
            pushup(root);
            return;
        }
        const int old = root;
        if (ch[old][0] && ch[old][1]) {
            // Splaying the predecessor leaves `old` as its right child with an
            // empty left subtree, so old's right subtree is re-hung under pre.
            const int pre = _getprev();
            fa[ch[old][1]] = pre;
            ch[pre][1] = ch[old][1];
            release(old);
            pushup(root);
        } else {
            root = ch[old][0] ? ch[old][0] : ch[old][1];
            if (root) fa[root] = 0;
            release(old);
        }
    }

    /// 1 + number of stored elements strictly less than x (x need not be
    /// present). Splays the matching node, or the last node visited on a miss.
    int getrank(int x) {
        int less = 0, cur = root, last = 0;
        while (cur) {
            last = cur;
            if (x < val[cur]) {
                cur = ch[cur][0];
                continue;
            }
            less += siz[ch[cur][0]];
            if (x == val[cur]) {
                splay(cur);
                return less + 1;
            }
            less += cnt[cur];
            cur = ch[cur][1];
        }
        // Splay even on a miss: otherwise repeated misses down a deep path
        // cost O(depth) each and the amortized O(log n) bound is lost.
        if (last) splay(last);
        return less + 1;
    }

    /// k-th smallest stored element (1-based, duplicates counted), or inf when
    /// k is outside [1, size]. Splays the node found.
    int getbyrank(int k) {
        if (k < 1 || k > siz[root]) return inf;
        int cur = root;
        while (true) {
            if (k <= siz[ch[cur][0]]) {
                cur = ch[cur][0];
                continue;
            }
            k -= siz[ch[cur][0]] + cnt[cur];
            if (k <= 0) {
                splay(cur);
                return val[cur];
            }
            cur = ch[cur][1];
        }
    }

    /// Node holding the largest key in the root's left subtree (splayed to
    /// the root), or -1 if there is none.
    int _getprev() {
        if (!root || !ch[root][0]) return -1;
        int cur = ch[root][0];
        while (ch[cur][1]) cur = ch[cur][1];
        splay(cur);
        return cur;
    }

    /// Largest stored value strictly less than x, or -inf if none.
    int getprev(int x) {
        push(x);  // x becomes the root, so its predecessor is in the left subtree
        const int node = _getprev();
        int ans = -inf;
        if (node >= 0) ans = val[node];
        del(x);  // always remove the probe, including when nothing was found
        return ans;
    }

    /// Node holding the smallest key in the root's right subtree (splayed to
    /// the root), or -1 if there is none.
    int _getnext() {
        if (!root || !ch[root][1]) return -1;
        int cur = ch[root][1];
        while (ch[cur][0]) cur = ch[cur][0];
        splay(cur);
        return cur;
    }

    /// Smallest stored value strictly greater than x, or inf if none.
    int getnext(int x) {
        push(x);  // x becomes the root, so its successor is in the right subtree
        const int node = _getnext();
        int ans = inf;
        if (node >= 0) ans = val[node];
        del(x);  // always remove the probe, including when nothing was found
        return ans;
    }

    /// Whether x is stored. Splays x's node (or the last node visited) to the
    /// root, which `del` relies on.
    bool contains(int x) {
        if (!root) return false;
        getrank(x);
        return val[root] == x;
    }

private:
    static constexpr int inf = 0x7fffffff;
    int root = 0, tot = 0, ttop = 0;
    std::vector<std::vector<int>> ch;  // ch[x][0] / ch[x][1]: left / right child
    std::vector<int> trash, val, cnt, siz, fa;

    void pushup(int x) { siz[x] = siz[ch[x][0]] + siz[ch[x][1]] + cnt[x]; }
    int where(int x) { return ch[fa[x]][1] == x; }
    void clear(int x) { val[x] = cnt[x] = siz[x] = ch[x][0] = ch[x][1] = fa[x] = 0; }
    /// Clear node x and push its index onto the free stack for reuse.
    void release(int x) {
        clear(x);
        trash[++ttop] = x;
    }
    void rotate(int x) {
        int y = fa[x], z = fa[fa[x]], chk = where(x);
        ch[y][chk] = ch[x][chk ^ 1];
        if (ch[x][chk ^ 1]) fa[ch[x][chk ^ 1]] = y;
        ch[x][chk ^ 1] = y;
        fa[y] = x;
        fa[x] = z;
        if (z) ch[z][ch[z][1] == y] = x;
        pushup(y);
        pushup(x);
    }
    void splay(int x) {
        for (int f = fa[x]; f = fa[x], f; rotate(x))
            if (fa[f]) rotate(where(f) == where(x) ? f : x);
        root = x;
    }
    /// Reuse a freed index when available, otherwise take a fresh one.
    int getnode() { return ttop ? trash[ttop--] : ++tot; }
};
// Capacities: the array is sized for n up to about 5e4 (1-based, with slack);
// the segment tree uses heap indexing (children 2p and 2p+1), hence ~4x nodes.
constexpr int N = 50000 + 10;
constexpr int N4 = 4 * 50000 + 10;
int n, m;
int a[N];         // a[1..n]: current array values
SplayTree e[N4];  // e[p]: multiset of a[l..r] for segment-tree node p
void _build(int l, int r, int p) {
e[p].init(r - l + 1);
for (int i = l; i <= r; i++) e[p].push(a[i]);
if (l == r) return;
int mid = (l + r) >> 1;
_build(l, mid, p << 1);
_build(mid + 1, r, p << 1 | 1);
}
/// Build the whole segment tree over positions 1..n, rooted at node 1.
void build() { _build(1, n, 1); }
/// Number of values strictly less than val among positions in [ql, qr] that
/// also lie inside node p's range [l, r].
int _getrank(int l, int r, int p, int ql, int qr, int val) {
    if (ql <= l && r <= qr) return e[p].getrank(val) - 1;
    const int mid = (l + r) >> 1;
    const int fromLeft = ql <= mid ? _getrank(l, mid, p << 1, ql, qr, val) : 0;
    const int fromRight = qr > mid ? _getrank(mid + 1, r, p << 1 | 1, ql, qr, val) : 0;
    return fromLeft + fromRight;
}
/// Rank of val within a[ql..qr]: 1 + count of elements strictly less than val.
int getrank(int ql, int qr, int val) {
    const int smaller = _getrank(1, n, 1, ql, qr, val);
    return smaller + 1;
}
/// k-th smallest value in a[ql..qr], found by binary searching for the largest
/// v in [0, 1e8] whose range rank is <= k. Assumes values lie in [0, 1e8]
/// (the search bounds) — NOTE(review): confirm against the input constraints.
int getbyrank(int ql, int qr, int k) {
    int lo = 0, hi = 1e8;
    while (lo < hi) {
        const int mid = lo + (hi - lo + 1) / 2;  // upper midpoint: loop always progresses
        if (getrank(ql, qr, mid) <= k) lo = mid;
        else hi = mid - 1;
    }
    return lo;
}
/// Walk from node p down to the leaf for pos, replacing one copy of the old
/// value a[pos] with val in every splay tree along the way.
void _update(int l, int r, int p, int pos, int val) {
    for (;;) {
        e[p].del(a[pos]);
        e[p].push(val);
        if (l == r) return;
        const int mid = (l + r) >> 1;
        if (pos <= mid) {
            r = mid;
            p <<= 1;
        } else {
            l = mid + 1;
            p = p << 1 | 1;
        }
    }
}
/// Point assignment a[pos] = val.
void update(int pos, int val) {
    _update(1, n, 1, pos, val);  // still needs the old a[pos] to delete it
    a[pos] = val;
}
/// Maximum of SplayTree::getprev(val) over the fully covered nodes of the query
/// [ql, qr] inside node p's range [l, r]; -0x7fffffff when nothing is found.
int _getprev(int l, int r, int p, int ql, int qr, int val) {
    if (ql <= l && r <= qr) return e[p].getprev(val);
    const int mid = (l + r) >> 1;
    int best = -0x7fffffff;
    if (ql <= mid) {
        const int cand = _getprev(l, mid, p << 1, ql, qr, val);
        if (cand > best) best = cand;
    }
    if (qr > mid) {
        const int cand = _getprev(mid + 1, r, p << 1 | 1, ql, qr, val);
        if (cand > best) best = cand;
    }
    return best;
}
/// Largest value in a[ql..qr] strictly less than val, or -0x7fffffff if none.
int getprev(int ql, int qr, int val) {
    const int answer = _getprev(1, n, 1, ql, qr, val);
    return answer;
}
/// Minimum of SplayTree::getnext(val) over the fully covered nodes of the query
/// [ql, qr] inside node p's range [l, r]; 0x7fffffff when nothing is found.
int _getnext(int l, int r, int p, int ql, int qr, int val) {
    if (ql <= l && r <= qr) return e[p].getnext(val);
    const int mid = (l + r) >> 1;
    int best = 0x7fffffff;
    if (ql <= mid) {
        const int cand = _getnext(l, mid, p << 1, ql, qr, val);
        if (cand < best) best = cand;
    }
    if (qr > mid) {
        const int cand = _getnext(mid + 1, r, p << 1 | 1, ql, qr, val);
        if (cand < best) best = cand;
    }
    return best;
}
/// Smallest value in a[ql..qr] strictly greater than val, or 0x7fffffff if none.
int getnext(int ql, int qr, int val) {
    const int answer = _getnext(1, n, 1, ql, qr, val);
    return answer;
}
/// Read n, m and the array, build the structure, then answer m operations:
/// 1 rank, 2 k-th, 3 point update, 4 predecessor, 5 successor.
/// Any other opcode is ignored without reading its arguments.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cin >> n >> m;
    for (int i = 1; i <= n; ++i) std::cin >> a[i];
    build();
    while (m-- > 0) {
        int opt;
        std::cin >> opt;
        if (opt == 3) {
            int pos, k;
            std::cin >> pos >> k;
            update(pos, k);
        } else if (opt == 1 || opt == 2 || opt == 4 || opt == 5) {
            int l, r, k;
            std::cin >> l >> r >> k;
            int answer;
            if (opt == 1) answer = getrank(l, r, k);
            else if (opt == 2) answer = getbyrank(l, r, k);
            else if (opt == 4) answer = getprev(l, r, k);
            else answer = getnext(l, r, k);
            std::cout << answer << '\n';
        }
    }
    return 0;
}
// Forum sign-off ("玄关"), not part of the program.