先将 a 序列排序,那么一个区间可以代表 trie 上一颗子树。
然后在 trie 上逐位考虑,维护一个子树对集合,每次计算当前位异或值为 1 的点对数量,决定往哪边走。
中间需要算上两棵子树之间点对异或值的和。预处理每一层 1 的个数前缀和,然后逐位算。
跑的飞快,至今没搞懂到底是 O(nlogv) 还是 O(nlog2v)?
#include <bits/stdc++.h>
using i64 = long long;
const int P = 1e9 + 7;
const int N = 50000;
struct Inter { int l, r, mid; };
struct Pair { Inter x, y; };
int a[N];
int s1[31][N + 1];
Pair f[N], g[N];
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int n;
i64 m;
std::cin >> n >> m;
for (int i = 0; i < n; ++i) std::cin >> a[i];
std::sort(a, a + n);
m *= 2;
for (int i = 0; i <= 30; ++i)
for (int j = 1; j <= n; ++j)
s1[i][j] = s1[i][j - 1] + ((a[j - 1] >> i) & 1);
i64 ans = 0;
int cf = 0;
f[cf++] = Pair{Inter{0, n, -1}, Inter{0, n, -1}};
for (int b = 30; ~b; --b) {
i64 sum = 0;
for (int i = 0; i < cf; ++i) {
f[i].x.mid = std::lower_bound(a + f[i].x.l, a + f[i].x.r, 1 << b) - a;
f[i].y.mid = std::lower_bound(a + f[i].y.l, a + f[i].y.r, 1 << b) - a;
sum += (i64) (f[i].x.mid - f[i].x.l) * (f[i].y.r - f[i].y.mid);
sum += (i64) (f[i].x.r - f[i].x.mid) * (f[i].y.mid - f[i].y.l);
}
int cg = 0;
if (sum >= m) {
ans += m * (1 << b);
for (int i = 0; i < cf; ++i) {
if (f[i].x.l < f[i].x.mid && f[i].y.mid < f[i].y.r)
g[cg++] = Pair{Inter{f[i].x.l, f[i].x.mid, -1}, Inter{f[i].y.mid, f[i].y.r, -1}};
if (f[i].x.mid < f[i].x.r && f[i].y.l < f[i].y.mid)
g[cg++] = Pair{Inter{f[i].x.mid, f[i].x.r, -1}, Inter{f[i].y.l, f[i].y.mid, -1}};
}
} else {
ans += sum * (1 << b);
m -= sum;
// 下面这里比较迷惑
for (int j = 0; j < b; ++j) {
sum = 0;
for (int i = 0; i < cf; ++i) {
int x = s1[j][f[i].x.mid] - s1[j][f[i].x.l];
int y = s1[j][f[i].y.r] - s1[j][f[i].y.mid];
sum += (i64) x * (f[i].y.r - f[i].y.mid - y);
sum += (i64) (f[i].x.mid - f[i].x.l - x) * y;
x = s1[j][f[i].x.r] - s1[j][f[i].x.mid];
y = s1[j][f[i].y.mid] - s1[j][f[i].y.l];
sum += (i64) x * (f[i].y.mid - f[i].y.l - y);
sum += (i64) (f[i].x.r - f[i].x.mid - x) * y;
}
ans += sum * (1 << j);
}
for (int i = 0; i < cf; ++i) {
if (f[i].x.l < f[i].x.mid && f[i].y.l < f[i].y.mid)
g[cg++] = Pair{Inter{f[i].x.l, f[i].x.mid, -1}, Inter{f[i].y.l, f[i].y.mid, -1}};
if (f[i].x.mid < f[i].x.r && f[i].y.mid < f[i].y.r)
g[cg++] = Pair{Inter{f[i].x.mid, f[i].x.r, -1}, Inter{f[i].y.mid, f[i].y.r, -1}};
}
}
std::copy(g, g + cg, f);
cf = cg;
for (int i = 0; i < n; ++i) a[i] &= (1 << b) - 1;
}
ans = (ans / 2) % P;
std::cout << ans << '\n';
return 0;
}