蒟蒻萌新求助基础LCA
查看原帖
蒟蒻萌新求助基础LCA
299883
HYdroKomide楼主2022/2/21 22:00
#include<cstdio>
#include<vector>
#define ri register int
using namespace std;
// N: array bound for 1-based vertex ids up to 1e5 (presumably the problem's
//    limit on n — confirm against the statement).
// LG: number of binary-lifting levels stored per vertex in f[][].
//    dfs() fills level i whenever 2^i <= depth, i.e. up to level 16 for a
//    depth of 1e5. The lg[] table built in main holds floor(log2(i)) + 1, so
//    lca() indexes f[x][lg[d[x]]] with lg values up to 17. find() also needs
//    bit 16 for step counts >= 65536. The former value 16 (levels 0..15) let
//    deep trees read and write past the end of each row of f.
const int N=1e5+1,LG=18;
static_assert(N-1<(1<<(LG-1)),"LG too small: lg[] values for depths up to N-1 must be < LG");
// Query scratch: x, y, l (LCA), dis, mid, ans1, ans2.
// d[v]: depth (root = 1).
// f[v][k]: 2^k-th ancestor of v (0 where no such ancestor).
// lg[]: see above.
// sz[v]: subtree size.
// g[v]: adjacency list.
int n,m,x,y,l,dis,mid,d[N],f[N][LG],lg[N],sz[N],ans1,ans2;
std::vector<int>g[N];
// Walks the subtree rooted at x, whose parent is fa (0 for the root), and fills
// three things for every vertex v in it:
//   - d[v]: depth. d[0] is never written, so the root gets depth 1.
//   - f[v][k]: the binary-lifting ancestor table, for every k with 2^k <= d[v].
//   - sz[v]: the subtree size.
// Implemented recursively, so the call depth equals the tree height.
void dfs(int x,int fa){
	f[x][0]=fa;
	d[x]=d[fa]+1;
	// The 2^k-th ancestor is the 2^(k-1)-th ancestor of the 2^(k-1)-th ancestor.
	for(int k=1;(1<<k)<=d[x];++k){
		const int half=f[x][k-1];
		f[x][k]=f[half][k-1];
	}
	sz[x]=1;
	for(int child:g[x]){
		if(child==fa)continue;
		dfs(child,x);
		sz[x]+=sz[child];
	}
}
inline int lca(int x,int y){
	if(d[x]<d[y])swap(x,y);
	for(ri i=lg[d[x]];i>=0;i--)
		if(d[f[x][i]]>=d[y])
			x=f[x][i];
	if(x==y)return x;
	for(ri i=lg[d[x]];i>=0;i--)
		if(f[x][i]!=f[y][i])
			x=f[x][i],y=f[y][i];
	return f[x][0];
}
inline int find(int ori,int step){
	int ret=ori;
	if(step<0)return ret;
	for(ri i=15;i>=0;i--)
		if((step>>i)%2==1)
			ret=f[ret][i];
	return ret;
}
int main(){
    scanf("%d",&n);
    for(ri i=1;i<n;i++){
    	scanf("%d%d",&x,&y);
    	g[x].push_back(y);
    	g[y].push_back(x);
	}
	for(ri i=1;i<=n;i++)lg[i]=lg[i-1]+(i==1<<lg[i-1]);
	dfs(1,0);
	scanf("%d",&m);
	while(m--){
		scanf("%d%d",&x,&y);
		l=lca(x,y);
		dis=d[x]+d[y]-d[l]*2;
		if(x==y)printf("%d\n",n);
		else if(dis%2==1)printf("0\n");
		else if(d[x]==d[y]){
			ans1=find(x,d[x]-d[l]-1);
			ans2=find(y,d[y]-d[l]-1);
			printf("%d\n",n-sz[ans1]-sz[ans2]);
		}
		else{
			if(d[x]<d[y])swap(x,y);
			mid=find(x,dis/2);
			ans1=find(x,d[x]-d[mid]-1);
			printf("%d\n",sz[mid]-sz[ans1]);
		}
	}
    return 0;
}

RT（如题），求帮忙找错 qwq

2022/2/21 22:00
加载中...