求大家帮忙看看,全部测试点都re了....调了好几个小时了...
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
int const N = 1e5 + 10;
vector <LL> v;
int n, m;
LL a[N];
int h[N], e[N * 2], ne[N * 2], idx; // 邻接表数组
int dep[N], dfn[N], wson[N], sz[N], top[N], fa[N], num;
int cnt, root[N]; // 计数变量和根
struct Node {
int l, r; // 左儿子和右儿子
int sum; // 记录该节点的值
}hjt[N * 40];
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
// 离散化查询
int getid(int x) {
return lower_bound(v.begin(), v.end(), x) - v.begin() + 1;
}
int lca(int u, int v) {
while (top[u] != top[v]) {
if (dep[top[u]] > dep[top[v]]) u = fa[top[u]];
else v = fa[top[v]];
}
return dep[u] > dep[v] ? v : u;
}
// 插入操作
void insert(int l, int r, int pre, int &now, int p) {
hjt[++cnt] = hjt[pre];
now = cnt; // 当前节点
hjt[now].sum++; // 更新该节点的值
if (l == r) return; // 如果递归到叶节点
int m = (l + r ) >> 1; // 找中间节点
if (p <= m) // 如果要插入的位置小于等于中间点,插入左子树
insert(l, m, hjt[pre].l, hjt[now].l, p);
else insert(m + 1, r, hjt[pre].r, hjt[now].r, p); // 往右子树插入
}
// 查询操作
int query(int l, int r, int this_lca, int fa_lca, int L, int R, int k) {
if (l == r) return l; // 递归到叶节点
int m = (l + r) >> 1; // 计算中点
int tmp = hjt[hjt[R].l].sum + hjt[hjt[L].l].sum - hjt[hjt[this_lca].l].sum - hjt[hjt[fa[fa_lca]].l].sum; // 计算左子树的包含的数字个数
if (k <= tmp) query(l, m, hjt[this_lca].l, hjt[fa_lca].l, hjt[L].l, hjt[R].l, k); // 如果k小于等于tmp,说明要查询的数字在左子树
else return query(m + 1, r, hjt[this_lca].r, hjt[fa_lca].r, hjt[L].r, hjt[R].r, k - tmp); // 否则在右子树上,注意这里要减去左子树的个数
}
// 得到sz, fa, dep, wson数组
void dfs1(int u) {
dep[u] = dep[fa[u]]+1;
insert(1, n, root[fa[u]], root[u], getid(a[u]));
sz[u] = 1;
for(int i = h[u]; ~i; i = ne[i]) {
int j=e[i];
if(j == fa[u]) continue;
fa[j] = u;
dfs1(j);
sz[u] += sz[j];
if(sz[j] > sz[wson[u]]) wson[u] = j; // 这里要注意根节点不能设为0,否则根节点的最重链无法更新,始终为0
}
}
// 得到dfn, top数组
void dfs2(int u, int nowtop) {
dfn[u] = ++num;
//以搜索序重排权值
top[u] = nowtop;
if(wson[u]) dfs2(wson[u], nowtop); // 先搜索重儿子
for(int i = h[u]; ~i; i = ne[i]) {// 然后搜索轻儿子
int y=e[i];
if(y ==fa[u]||y == wson[u]) continue;
dfs2(y, y);
}
}
int main() {
freopen("in.txt", "r", stdin);
memset(h, -1, sizeof h);
cin >> n >> m;
for (int i = 1; i <= n; ++i) {
scanf("%lld", &a[i]);
v.push_back(a[i]);
}
sort(v.begin(), v.end());
v.erase(unique(v.begin(), v.end()), v.end());
for (int i = 1, a, b; i <= n - 1; ++i) {
scanf("%d%d", &a, &b);
add(a, b), add(b, a);
}
n = v.size();
dfs1(1);
dfs2(1, 1);
LL last = 0;
for (int i = 1, x, y, k; i <= m; ++i) {
scanf("%d%d%d", &x, &y, &k);
int this_lca = lca(x, y), fa_lca = fa[this_lca];
x ^= last;
last = v[query(1, n, this_lca, fa_lca, root[x], root[y], k) - 1];
printf("%lld\n", last);
}
return 0;
}