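/*
哈哈哈哈哈!!我用暴力LCA+枚举所有点+O2过了!!
查看原帖
哈哈哈哈哈!!我用暴力LCA+枚举所有点+O2过了!!
519384
Link_Cut_Y楼主2022/1/16 12:36

(Forum-page residue, kept as a comment so the file compiles as C++.
 Title: "Hahaha!! I got AC with brute-force LCA + enumerating all nodes + O2!!";
 "查看原帖" = "view original post"; "楼主" = "original poster".)
*/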
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>

using namespace std;

// Global tree state. Nodes are numbered 1..n; index 0 acts as the "no node"
// sentinel (sz[0] == 0, and the root's parent is recorded as 0).
const int N = 10010, M = N << 1;
int h[N], ne[M], e[M], idx;                 // adjacency lists: head, next edge, edge target; idx = last edge id used
int top[N], dep[N], fa[N], son[N], sz[N];   // chain head, depth, parent, heavy child, subtree size
int p[N], n, r, m;                          // p[u]: #ordered pairs (i, j) with LCA(i, j) == u; n nodes, root r, m queries

// Append the directed edge a -> b to a's adjacency list.
// Edge ids start at 1 so that 0 terminates every list (h[]/ne[] are zero-initialised).
void add(int a, int b)
{
	++idx;
	e[idx] = b;
	ne[idx] = h[a];
	h[a] = idx;
}

// First decomposition pass over the subtree of u (reached from `father`):
// records depth, parent and subtree size, and picks as son[u] the child with
// the largest subtree.
void dfs1(int u, int father, int depth)
{
	dep[u] = depth;
	fa[u] = father;
	sz[u] = 1;

	for (int i = h[u]; i; i = ne[i])
	{
		int v = e[i];
		if (v == father) continue;
		dfs1(v, u, depth + 1);
		sz[u] += sz[v];
		// son[u] starts as 0 and sz[0] == 0, so the first child always replaces it.
		if (sz[son[u]] < sz[v]) son[u] = v;
	}
}

// Second decomposition pass: u lies on the chain whose head is t. The heavy
// child continues that chain; every other child starts a new chain at itself.
// Must run after dfs1 (reads fa[] and son[]).
void dfs2(int u, int t)
{
	top[u] = t;
	if (!son[u]) return;   // no heavy child means no children at all
	dfs2(son[u], t);

	for (int i = h[u]; i; i = ne[i])
	{
		int v = e[i];
		if (v == fa[u] || v == son[u]) continue;
		dfs2(v, v);
	}
}

// Lowest common ancestor of u and v via the chains built by dfs1/dfs2:
// repeatedly lift the endpoint whose chain head is deeper to the parent of
// that head; once both share a chain, the shallower node is the LCA.
int LCA(int u, int v)
{
	while (top[u] != top[v])
	{
		if (dep[top[u]] < dep[top[v]]) std::swap(u, v);
		u = fa[top[u]];
	}
	return dep[u] <= dep[v] ? u : v;
}

// Fill p[u] with the number of ordered pairs (i, j), 1 <= i, j <= n, whose
// lowest common ancestor is u (pairs with i == j count, their LCA being i).
//
// Both endpoints of such a pair lie in u's subtree, but not both inside the
// same child's subtree, so p[u] = sz[u]^2 - sum over children c of sz[c]^2.
// This is O(n), versus O(n^2 log n) for calling LCA on every pair.
// Requires dfs1 to have been run from the root; values fit in int for n < N.
void init()
{
	for (int u = 1; u <= n; u ++ ) p[u] = sz[u] * sz[u];
	for (int u = 1; u <= n; u ++ )
		if (fa[u]) p[fa[u]] -= sz[u] * sz[u];   // the root (fa == 0) has no parent to update
}

// Read "n r m", then n - 1 undirected edges, then m query nodes; print p[a]
// (the number of ordered node pairs whose LCA is a) for each query.
// Malformed header/edge input exits with status 1; a truncated query list
// stops answering early. Assumes n < N (array capacity) — not checked here.
int main()
{
	if (scanf("%d%d%d", &n, &r, &m) != 3) return 1;
	if (r < 1 || r > n) return 1;   // root must be a real node before indexing the arrays

	for (int i = 1; i <= n - 1; i ++ )
	{
		int a, b;
		if (scanf("%d%d", &a, &b) != 2) return 1;
		add(a, b), add(b, a);
	}

	dfs1(r, 0, 1), dfs2(r, r);
	init();

	// Counted loop: unlike `while (m--)`, a negative m runs zero times.
	for (int q = 0; q < m; q ++ )
	{
		int a;
		if (scanf("%d", &a) != 1) break;   // never use an unread (uninitialised) a
		// Only 1..n hold answers; out-of-range indices would read past p[].
		printf("%d\n", (a >= 1 && a <= n) ? p[a] : 0);
	}

	return 0;
}
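/*
2022/1/16 12:36
加载中...
(Forum-page residue: post timestamp and "Loading...", kept as a comment so the file compiles.)
*/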