替罪羊树68pts求助
查看原帖
替罪羊树68pts求助
341650
tribool4_in楼主2021/8/21 18:28

rt

#include <bits/stdc++.h>
using namespace std;
// Capacity of the static node pool `a`; node ids are indices into that pool.
constexpr int N = 1000000 + 10;
// Balance factor consulted by is_rebuild() when deciding to rebuild a subtree.
constexpr double alpha = 0.7;
// Read the next decimal integer from stdin.
//
// Skips every character that is not a digit; if a '-' is seen among the
// skipped characters the result is negated. Consumes the digits plus the one
// character that follows them.
//
// Returns the parsed value, or 0 when end of input is reached before any
// digit (a call at EOF would otherwise spin forever).
inline int read() {
    int value = 0;
    bool negative = false;
    // Keep the result of getchar() as int so that EOF stays distinguishable
    // from every real character.
    int ch = std::getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') negative = true;
        ch = std::getchar();
    }
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = std::getchar();
    }
    return negative ? -value : value;
}
int n;    // number of operations still to be processed (read in main)
int rt;   // id of the tree's root node; 0 while the tree is empty
int cnt;  // number of node ids handed out so far by insert()
// One tree node. Nodes live in the static pool `a` and refer to each other
// by index; index 0 is used as the "no node" marker.
struct node {
    int l;    // index of the left child, 0 if none
    int r;    // index of the right child, 0 if none
    int val;  // key stored in this node
    int cnt;  // copies of val currently present; 0 means lazily deleted
    int s;    // number of nodes in this subtree, deleted ones included
    int sm;   // sum of cnt over this subtree
    int sd;   // number of nodes in this subtree whose cnt is non-zero
} a[N];
void update(int x) {
    a[x].s = a[a[x].l].s + a[a[x].r].s + 1;
    a[x].sm = a[a[x].l].sm + a[a[x].r].sm + a[x].cnt;
    a[x].sd = a[a[x].l].sd + a[a[x].r].sd + (a[x].cnt != 0);
}
// Decide whether the subtree rooted at x has to be rebuilt.
//
// Returns true when either
//   * one child holds at least alpha * s of the subtree's nodes, or
//   * at most alpha * s of the subtree's nodes are still live (cnt != 0).
//
// The decision deliberately does not look at a[x].cnt: a lazily deleted
// subtree root must still be able to trigger a rebuild, otherwise chains of
// deleted nodes (e.g. "insert k+1, erase k" repeated) are never compacted and
// the tree degenerates into a list.
bool is_rebuild(int x) {
    const double limit = alpha * a[x].s;
    const int heavier_child = std::max(a[a[x].l].s, a[a[x].r].s);
    const bool lopsided = limit <= static_cast<double>(heavier_child);
    const bool mostly_deleted = static_cast<double>(a[x].sd) <= limit;
    return lopsided || mostly_deleted;
}
// Scratch buffer: ids of the live nodes of the subtree being rebuilt, in key order.
std::vector<int> v;
// In-order walk of the subtree rooted at x that appends the id of every node
// with cnt != 0 to the scratch buffer `v`; lazily deleted nodes are skipped.
void dfs(int x) {
    // Recurse into the left subtree, then step along the right spine.
    for (int cur = x; cur != 0; cur = a[cur].r) {
        dfs(a[cur].l);
        if (a[cur].cnt != 0) {
            v.push_back(cur);
        }
    }
}
// Build a balanced subtree out of the node ids v[l], ..., v[r - 1] (half-open
// range) and return the id of its root, or 0 when the range is empty.
int build(int l, int r) {
    if (r <= l) {
        return 0;
    }
    const int mid = (l + r) / 2;
    const int root = v[mid];  // the middle element becomes the subtree root
    a[root].l = build(l, mid);
    a[root].r = build(mid + 1, r);
    update(root);  // children are final, so the aggregates can be refreshed
    return root;
}
// Replace the subtree referenced by x with a balanced tree made of its live
// nodes. x is updated to the new root id (0 if no live node remains).
void rebuild(int &x) {
    v.clear();
    dfs(x);  // collect the live node ids in key order
    const int live = static_cast<int>(v.size());
    x = build(0, live);
}
void insert(int &x, int v) {
    if (!x) {
        x = ++cnt;
        if (!rt) rt = 1;
        a[x].l = a[x].r = 0;
        a[x].cnt = a[x].s = a[x].sd = a[x].sm = 1;
        a[x].val = v;
    } else {
        if (a[x].val == v) a[x].cnt++;
        else if (a[x].val < v) insert(a[x].r, v);
        else if (a[x].val > v) insert(a[x].l, v);
        update(x);
        if (is_rebuild(x)) rebuild(x);
    }
}
void erase(int &x, int v) {
    if (!x) return;
    if (a[x].val == v && a[x].cnt) a[x].cnt--;
    else if (a[x].val < v) erase(a[x].r, v);
    else if (a[x].val > v) erase(a[x].l, v);
    update(x);
    if (is_rebuild(x)) rebuild(x);
}
// Rank of key v inside the subtree rooted at x: one plus the number of
// stored copies (summed cnt) of keys strictly smaller than v. v itself does
// not have to be present.
int find(int x, int v) {
    if (!x) return 1;
    if (a[x].val < v) {
        // This node and its whole left subtree are smaller than v.
        return find(a[x].r, v) + a[a[x].l].sm + a[x].cnt;
    }
    if (a[x].val > v) {
        return find(a[x].l, v);
    }
    // Key matches. The answer is the same whether the node is live or lazily
    // deleted (cnt == 0): only the left subtree holds smaller keys. Returning
    // here for both cases keeps every path on an explicit return value.
    return a[a[x].l].sm + 1;
}
// Return the k-th smallest stored key (1-based, duplicates counted) in the
// subtree rooted at x, or 0 when the search runs off the tree.
int at(int x, int k) {
    int cur = x;
    while (cur != 0) {
        const int smaller = a[a[cur].l].sm;  // copies stored in the left subtree
        const int here = a[cur].cnt;
        if (k <= smaller) {
            cur = a[cur].l;
        } else if (k > smaller + here) {
            k -= smaller + here;  // skip the left subtree and this node
            cur = a[cur].r;
        } else {
            return a[cur].val;  // smaller < k <= smaller + here
        }
    }
    return 0;
}
// Read the operation count, then process each "op x" pair:
//   1 insert x        2 erase one x      3 print rank of x
//   4 print x-th key  5 print the key just below x
//   6 print the key just above x
int main() {
    n = read();
    while (n-- != 0) {
        const int op = read();
        const int x = read();
        switch (op) {
            case 1:
                insert(rt, x);
                break;
            case 2:
                erase(rt, x);
                break;
            case 3:
                printf("%d\n", find(rt, x));
                break;
            case 4:
                printf("%d\n", at(rt, x));
                break;
            case 5:
                // rank(x) - 1 is the position of the largest key below x
                printf("%d\n", at(rt, find(rt, x) - 1));
                break;
            case 6:
                // rank(x + 1) is the position of the smallest key above x
                printf("%d\n", at(rt, find(rt, x + 1)));
                break;
            default:
                break;
        }
    }
    return 0;
}
2021/8/21 18:28
加载中...