代码如下：
这份代码在 AcWing 上能通过，但在这里提交却通不过。
#include <bits/stdc++.h>
using namespace std;
// Upper bound on vertex / road counts; every global table below is sized from it.
const int N = 5e5 + 10;
// One undirected input road: endpoints x, y and its weight w.
struct side
{
    int x,y,w;
}rode[N];
// Linked-list adjacency of the spanning forest: edge i goes to e[i] with weight w[i],
// ne[i] is the next edge leaving the same vertex, h[v] is v's first edge (-1 = none).
// Every tree edge is stored in both directions, hence 2 * N edge slots; h is indexed
// by vertex, so N slots suffice.
int e[N * 2],ne[N * 2],h[N],w[N * 2],idx;
// p: union-find parent; dep: depth in the spanning tree;
// f[v][k]: 2^k-th ancestor of v; W[v][k]: minimum edge weight on that 2^k-step climb;
// lg[x]: number of bits of x (lg[0] = 0).
// lg is indexed by depths and depth differences (up to n - 1) and filled for 1..n, so it
// needs N entries; a fixed size of 200 overflows for n > 200, which is undefined
// behaviour and can silently overwrite neighbouring globals.
int p[N],dep[N],f[N][32],W[N][32],lg[N];
bool st[N];   // st[v]: v has already been reached by dfs
int n,m;
// Pushes a directed edge a -> b of weight we onto the front of a's adjacency list.
void add(int a,int b,int we)
{
    e[idx] = b;
    w[idx] = we;
    ne[idx] = h[a];   // link to a's previous first edge
    h[a] = idx;
    ++idx;
}
// Union-find lookup: returns the root of x's set and re-points every vertex on the
// path from x directly at that root (full path compression).
int find(int x)
{
    int root = x;
    while (p[root] != root) root = p[root];
    // Second pass: hang each vertex on the path directly under the root.
    while (p[x] != root)
    {
        int next = p[x];
        p[x] = root;
        x = next;
    }
    return root;
}
// Sort predicate: heavier roads come first, so Kruskal yields a maximum spanning forest.
bool cmp(side a,side b)
{
    return b.w < a.w;
}
void krusrul()
{
sort(rode + 1,rode + 1 + m,cmp);
for(int i = 1;i <= m;i ++)
{
int p1 = find(rode[i].x),p2 = find(rode[i].y);
int x = rode[i].x,y = rode[i].y,w = rode[i].w;
if(p1 != p2)
{
p[p1] = p2;
add(x,y,w);
add(y,x,w);
}
}
}
// Recursive walk over one tree of the spanning forest (fa == -1 marks the root).
// Marks u in st[], sets dep[u], and fills the binary-lifting rows f[u][k] / W[u][k]
// for every k >= 1 with 2^k <= dep[u]. Row 0 (parent and parent-edge weight) is
// written by the caller just before recursing into a child; it is not written here
// for the root.
void dfs(int u,int fa)
{
    st[u] = 1;
    dep[u] = (fa == -1) ? 0 : dep[fa] + 1;
    const int levels = lg[dep[u]];   // bit length of dep[u]
    for (int k = 1; k < levels; ++k)
    {
        const int half = f[u][k - 1];   // ancestor 2^(k-1) steps up
        f[u][k] = f[half][k - 1];
        W[u][k] = min(W[u][k - 1], W[half][k - 1]);
    }
    for (int edge = h[u]; edge != -1; edge = ne[edge])
    {
        const int v = e[edge];
        if (v != fa)
        {
            f[v][0] = u;
            W[v][0] = w[edge];
            dfs(v, u);
        }
    }
}
// Answers one query: the minimum edge weight on the a-b path of the maximum spanning
// forest, found by lifting both vertices to their lowest common ancestor with the
// f[][] / W[][] tables. Returns -1 when a and b are in different components.
// NOTE(review): the running minimum starts at 1e8, so this assumes edge weights stay
// below 1e8; for a == b (empty path) the value 1e8 itself is returned.
int lca(int a,int b)
{
    if(find(a) != find(b)) return -1;
    int ans = 1e8;
    if(dep[a] < dep[b]) swap(a,b);
    // Lift the deeper vertex to b's depth using the largest power-of-two jump that fits.
    while(dep[a] > dep[b])
    {
        int k = lg[dep[a] - dep[b]] - 1;   // 2^k <= depth gap
        // W[a][k] already covers every edge of this jump; reading W after the jump is
        // unnecessary and, once the depths match, would index W[a][-1] (out of bounds).
        ans = min(ans,W[a][k]);
        a = f[a][k];
    }
    if(a == b) return ans;
    // Climb both vertices in lockstep while their ancestors differ.
    for(int k = lg[dep[a]] - 1; k >= 0; --k)
    {
        if(f[a][k] != f[b][k])
        {
            ans = min(ans,min(W[a][k],W[b][k]));
            a = f[a][k], b = f[b][k];
        }
    }
    // The last two edges lead into the common ancestor.
    ans = min(ans,min(W[a][0],W[b][0]));
    return ans;
}
int main()
{
cin>>n>>m;
for(int i = 1;i <= m;i ++)
{
cin>>rode[i].x>>rode[i].y>>rode[i].w;
}
for(int i = 1;i <= n;i ++) p[i] = i;
memset(h,-1,sizeof h);
memset(W,0x3f,sizeof W);
krusrul();
lg[1] = 1;
for(int i = 1; i <= n; ++i)
lg[i] = lg[i-1] + (1 << lg[i-1] == i);
for(int i = 1;i <= n;i ++)
{
if(!st[i])
{
//W[i][0] = 1e9;
dfs(i,-1);
}
}
int q;
cin>>q;
while(q --)
{
int a,b;
cin>>a>>b;
cout<<lca(a,b)<<endl;
}
}