单 log 求助卡常
查看原帖
单 log 求助卡常
637796
Xy_top楼主2024/12/1 13:12

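思路：线段树合并求出每对相邻编号 (i, i+1) 在其 LCA 子树内的极长连续编号段 [L, R]（权值为该 LCA 的深度），再做 3 次二维数点（离线排序 + 最大值线段树）回答 k >= 2 的询问；k = 1 的询问直接用 ST 表求区间最大深度。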

代码:

#include <bits/stdc++.h>
#define int long long
#define For(i, a, b) for (int i = (a); i <= (b); i ++)
#define foR(i, a, b) for (int i = (a); i >= (b); i --)
using namespace std;
int n, q, cnt, ccnt;
// NOTE: `int` is macro-expanded to long long in this file. Arrays whose entries
// are bounded by n (vertex ids, depths, run lengths) use plain 32-bit `signed`
// instead: this halves the binary-lifting table, the sparse table and the
// lmax/rmax pools (roughly 160 MB less memory and a smaller cache footprint).
signed fa[500005][19], dep[500005];    // fa[v][i]: 2^i-th ancestor of v (0 above the root); dep[root] = 1
int ans[500005];                       // ans[id]: answer of the id-th query in input order
vector <int> G[500005], P[500005];     // G: adjacency lists; P[x]: every i with lca(i, i + 1) == x
// Node pool for the mergeable segment trees over positions [1, n].
// ls/rs/root stay `int` because update() binds them to `int &`.
int ls[10000005], rs[10000005], root[500005];
signed lmax[10000005], rmax[10000005]; // longest run of present positions at the left / right end of a node
int mx[2000005], lg[500005];           // mx: max segment tree used by the sweeps; lg[i] = floor(log2 i)
signed st[500005][19];                 // sparse table of the maximum depth over vertex-id ranges
struct Node {int L, R, res;} a[1000005];   // run [L, R] of ids inside one subtree whose root has depth res
bool cmp1 (const Node &n1, const Node &n2) {return n1.L < n2.L;}
bool cmp2 (const Node &n1, const Node &n2) {return n1.R - n1.L + 1 > n2.R - n2.L + 1;}
struct Query {int l, r, k, id;} u[500005];
bool cmpq1 (const Query &q1, const Query &q2) {return q1.l < q2.l;}
bool cmpq2 (const Query &q1, const Query &q2) {return q1.k > q2.k;}
// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters in front of it. If EOF is hit before any digit the
// function returns 0 instead of spinning forever on truncated input.
int read () {
	// getchar() returns an int so that EOF stays distinguishable from every byte
	int ch = getchar ();
	while (ch != EOF && (ch < '0' || ch > '9')) ch = getchar ();
	int x = 0;
	while (ch >= '0' && ch <= '9') {
		x = x * 10 + (ch - '0');
		ch = getchar ();
	}
	return x;
}
void dfs1 (int u) {
	For (i, 1, 18) fa[u][i] = fa[fa[u][i - 1] ][i - 1];
	st[u][0] = dep[u];
	for (int v : G[u]) {
		if (v == fa[u][0]) continue;
		fa[v][0] = u;
		dep[v] = dep[u] + 1;
		dfs1 (v);
	}
}
// Lowest common ancestor of u and v using the binary-lifting table fa[][]
// built by dfs1. The sentinel vertex 0 has depth 0, so jumps past the root
// are rejected by the depth test.
int lca (int u, int v) {
	if (dep[u] < dep[v]) swap (u, v);
	// raise the deeper vertex u until it sits on v's depth
	for (int j = 18; j >= 0; -- j) {
		int up = fa[u][j];
		if (dep[up] >= dep[v]) u = up;
	}
	if (u == v) return u;
	// raise both while they stay in different branches; they end just below the LCA
	for (int j = 18; j >= 0; -- j) {
		if (fa[u][j] == fa[v][j]) continue;
		u = fa[u][j];
		v = fa[v][j];
	}
	return fa[u][0];
}
void pushup (int l, int r, int k) {
	int mid = l + r >> 1;
	if (rmax[rs[k] ] == r - mid) rmax[k] = r - mid + rmax[ls[k] ];
	else rmax[k] = rmax[rs[k] ];
	if (lmax[ls[k] ] == mid - l + 1) lmax[k] = lmax[ls[k] ] + lmax[rs[k] ];
	else lmax[k] = lmax[ls[k] ];
}
// Merges the tree rooted at k1 into the tree rooted at k2 (both over [l, r])
// and returns the resulting root. k2's nodes are updated in place and k1's
// subtrees may be linked in directly, so later merges into the result can
// modify nodes that used to belong to k1.
int merge (int l, int r, int k1, int k2) {
	if (k1 == 0) return k2;
	if (k2 == 0) return k1;
	if (l == r) {
		lmax[k2] = max (lmax[k2], lmax[k1]);
		rmax[k2] = max (rmax[k2], rmax[k1]);
		return k2;
	}
	int mid = l + r >> 1;
	ls[k2] = merge (l, mid, ls[k1], ls[k2]);
	rs[k2] = merge (mid + 1, r, rs[k1], rs[k2]);
	pushup (l, r, k2);
	return k2;
}
// On the tree rooted at k (covering [l, r]), restricted to positions [x, y]:
// type == 1 returns the length of the run of present positions starting at x,
// any other type the length of the run ending at y. Callers pass x <= y.
int query (int l, int r, int k, int x, int y, int type) {
	if (k == 0) return 0;
	bool fromLeft = type == 1;
	if (l >= x && r <= y) return fromLeft ? lmax[k] : rmax[k];
	int mid = l + r >> 1;
	if (y <= mid) return query (l, mid, ls[k], x, y, type);
	if (x > mid) return query (mid + 1, r, rs[k], x, y, type);
	if (fromLeft) {
		// continue into the right half only if [x, mid] is entirely present
		int run = query (l, mid, ls[k], x, y, type);
		return run == mid - x + 1 ? run + query (mid + 1, r, rs[k], x, y, type) : run;
	}
	// continue into the left half only if [mid + 1, y] is entirely present
	int run = query (mid + 1, r, rs[k], x, y, type);
	return run == y - mid ? run + query (l, mid, ls[k], x, y, type) : run;
}
// Marks position x as present in the tree rooted at k (covering [l, r]).
// Any missing node on the path is taken from the pool by bumping cnt;
// k is updated in place when the root itself is created.
void update (int l, int r, int &k, int x) {
	if (k == 0) k = ++ cnt;
	if (l == r) {
		lmax[k] = 1;
		rmax[k] = 1;
		return;
	}
	int mid = l + r >> 1;
	if (x > mid) update (mid + 1, r, rs[k], x);
	else update (l, mid, ls[k], x);
	pushup (l, r, k);
}
// For every vertex x in the subtree of u:
// - builds root[x] as the set of vertex ids in x's subtree, by inserting x and
//   merging in the children's trees;
// - for each i in P[x] (main fills P[x] with the i such that lca(i, i + 1) == x),
//   records in a[] the maximal run [L, R] of consecutive ids around i, i + 1
//   that lies entirely in x's subtree, tagged with dep[x].
// Vertices are processed in reverse pre-order, so each vertex still comes after
// all of its descendants (as in a recursive post-order). This avoids the
// O(n)-deep recursion that a path-shaped tree with n = 5e5 would cause.
// Requires dfs1 to have filled fa[][0].
void dfs2 (int u) {
	vector <int> order, stk;
	stk.push_back (u);
	while (!stk.empty ()) {
		int x = stk.back ();
		stk.pop_back ();
		order.push_back (x);
		for (int v : G[x]) if (v != fa[x][0]) stk.push_back (v);
	}
	for (auto it = order.rbegin (); it != order.rend (); ++ it) {
		int x = *it;
		update (1, n, root[x], x);
		// children are finished (their own queries are done), so their trees may be consumed
		for (int v : G[x]) {
			if (v == fa[x][0]) continue;
			root[x] = merge (1, n, root[v], root[x]);
		}
		for (int i : P[x]) a[++ ccnt] = {i - query (1, n, root[x], 1, i, 2) + 1, i + query (1, n, root[x], i + 1, n, 1), dep[x]};
	}
}
// Point update on the max segment tree mx[] (node k covers [l, r], children
// at 2k and 2k + 1): position x becomes max(its current value, y).
void modify (int l, int r, int k, int x, int y) {
	if (l == r) {
		if (y > mx[k]) mx[k] = y;
		return;
	}
	int mid = l + r >> 1;
	int lc = k << 1, rc = lc | 1;
	if (x > mid) modify (mid + 1, r, rc, x, y);
	else modify (l, mid, lc, x, y);
	mx[k] = mx[lc] > mx[rc] ? mx[lc] : mx[rc];
}
// Maximum of mx[] over positions [x, y] on the tree node k covering [l, r].
// An empty range (x > y) yields 0.
int qmax (int l, int r, int k, int x, int y) {
	if (x > y) return 0;
	if (l >= x && r <= y) return mx[k];
	int mid = l + r >> 1;
	int best = 0;
	if (x <= mid) best = qmax (l, mid, k << 1, x, y);
	if (y > mid) {
		int right = qmax (mid + 1, r, k << 1 | 1, x, y);
		if (right > best) best = right;
	}
	return best;
}
void func1 () {
	sort (a + 1, a + ccnt + 1, cmp1);
	sort (u + 1, u + q + 1, cmpq1);
	int tail = 0;
	For (i, 1, q) {
		if (u[i].k == 1) continue;
		while (tail != ccnt && a[tail + 1].L <= u[i].l) modify (1, n, 1, a[tail + 1].R, a[tail + 1].res), ++ tail;
		ans[u[i].id] = max (ans[u[i].id], qmax (1, n, 1, u[i].r, n) );
	}
}
void func2 () {
	sort (a + 1, a + ccnt + 1, cmp2);
	sort (u + 1, u + q + 1, cmpq2);
	For (i, 1, 2000000) mx[i] = 0;
	int tail = 0;
	For (i, 1, q) {
		if (u[i].k == 1) continue;
		while (tail != ccnt && a[tail + 1].R - a[tail + 1].L + 1 >= u[i].k) modify (1, n, 1, a[tail + 1].R, a[tail + 1].res), ++ tail;
		ans[u[i].id] = max (ans[u[i].id], qmax (1, n, 1, u[i].l + u[i].k - 1, u[i].r - 1) );
	}
	For (i, 1, 2000000) mx[i] = 0;
	tail = 0;
	For (i, 1, q) {
		if (u[i].k == 1) continue;
		while (tail != ccnt && a[tail + 1].R - a[tail + 1].L + 1 >= u[i].k) modify (1, n, 1, a[tail + 1].L, a[tail + 1].res), ++ tail;
		ans[u[i].id] = max (ans[u[i].id], qmax (1, n, 1, u[i].l + 1, u[i].r - u[i].k + 1) );
	}
}
signed main () {
	lg[1] = 0;
	For (i, 2, 500000) lg[i] = lg[i / 2] + 1;
	n = read ();
	For (i, 1, n) root[i] = i;
	cnt = n;
	For (i, 2, n) {
		int u = read (), v = read ();
		G[u].push_back (v);
		G[v].push_back (u);
	}
	dep[1] = 1;
	dfs1 (1);
	For (i, 1, n - 1) P[lca (i, i + 1)].push_back (i);
	dfs2 (1);
	q = read ();
	For (j, 1, 18) For (i, 1, n - (1 << j) + 1) st[i][j] = max (st[i][j - 1], st[i + (1 << j - 1)][j - 1]);
	For (i, 1, q) {
		u[i] = {read (), read (), read (), i};
		if (u[i].k == 1) {
			int l = lg[u[i].r - u[i].l + 1];
			ans[u[i].id] = max (st[u[i].l][l], st[u[i].r - (1 << l) + 1][l]);
		}
	}
	func1 ();
	func2 ();
	For (i, 1, q) printf ("%lld\n", ans[i]);
	return 0;
}
/*
6
2 1
3 2
4 3
5 1
6 5
1
1 4 3
[l, r, k]（k >= 2）：找 L <= l 且 R >= r 的区间，按左端点从小到大加入即可（以 R 为下标，查询 [r, n] 内的最大深度）
[l, r, k]：找长度 >= k 且 l + k - 1 <= R <= r - 1 的区间，按区间长度从大到小加入即可（以 R 为下标）
[l, r, k]：找长度 >= k 且 l + 1 <= L <= r - k + 1 的区间，按区间长度从大到小加入即可（以 L 为下标）
*/
2024/12/1 13:12
加载中...