RT,思路是第一篇题解的。目前只写了k=3的部分,但测大样例时50000个询问错3个。小样例通过了。 万分感谢。
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N=200010;
int n,q,k,v[N],num[N],a,b,s,t,dep[N],lca[N][19],num2[N];
vector<int>g[N];
inline int read(){
int s=0;
bool f=0;
char c=getchar();
while(!isdigit(c)){
if(c=='-')f^=1;
c=getchar();
}
while(isdigit(c)){
s=(s<<1)+(s<<3)+(c^48);
c=getchar();
}
return f==0?s:-s;
}
struct MATRIX{
int n,m,a[3][3];
MATRIX operator*(MATRIX b){
MATRIX c;
c.n=n,c.m=b.m;
memset(c.a,63,sizeof(c.a));
for(int i=0;i<c.n;++i)
for(int j=0;j<c.m;++j)
for(int k=0;k<m;++k)
c.a[i][j]=min(c.a[i][j],a[i][k]+b.a[k][j]);
return c;
}
}base[N],stn[N][19],sts[N][19];
void dfs(int r,int f){
lca[r][0]=f;
dep[r]=dep[f]+1;
for(int i:g[r])
if(i!=f)dfs(i,r);
}
signed main(){
// freopen("test.in","r",stdin);
// freopen("test.out","w",stdout);
n=read(),q=read(),k=read();
memset(num,63,sizeof(num));
const int inf=num[0];
for(int i=1;i<=n;++i)v[i]=read();
for(int i=1;i<n;++i){
a=read(),b=read();
g[a].push_back(b);
g[b].push_back(a);
num[a]=min(num[a],v[b]);
num[b]=min(v[a],num[b]);
}
for(int i=1;i<=n;++i){
base[i].n=base[i].m=3;
base[i].a[0][0]=v[i];
base[i].a[0][1]=v[i];
base[i].a[0][2]=v[i];
base[i].a[1][0]=0;
base[i].a[1][1]=num[i];
base[i].a[1][2]=inf;
base[i].a[2][0]=inf;
base[i].a[2][1]=0;
base[i].a[2][2]=inf;
}
for(int i=1;i<=n;++i)stn[i][0]=sts[i][0]=base[i];
dfs(1,0);
for(int i=1;i<19;++i)
for(int j=1;j<=n;++j){
lca[j][i]=lca[lca[j][i-1]][i-1];
sts[j][i]=sts[lca[j][i-1]][i-1]*sts[j][i-1];
stn[j][i]=stn[j][i-1]*stn[lca[j][i-1]][i-1];
}
num2[0]=inf;
for(int i=1;i<=n;++i){
num2[i]=num[lca[i][0]];
for(int j:g[i])
num2[i]=min(num2[i],num[j]);
}
while(q--){
s=read(),t=read();
int pos=v[t];
MATRIX ans,tt;
tt.n=tt.m=3;
tt.a[0][0]=tt.a[1][1]=tt.a[2][2]=0;
tt.a[0][1]=tt.a[0][2]=tt.a[1][0]=tt.a[1][2]=tt.a[2][0]=tt.a[2][1]=inf;
ans=tt;
a=s,b=t;
if(dep[a]<dep[b])swap(a,b);
for(int i=18;i>=0;--i)
if(dep[a]-(1<<i)>=dep[b])a=lca[a][i];
int LCA;
if(a==b)LCA=a;
else{
for(int i=18;i>=0;--i)
if(lca[a][i]!=lca[b][i])a=lca[a][i],b=lca[b][i];
LCA=lca[a][0];
}
for(int i=18;i>=0;--i)
if(dep[s]-(1<<i)>=dep[LCA]){
tt=tt*stn[s][i];
s=lca[s][i];
}
tt=tt*base[LCA];
t=lca[t][0];
for(int i=18;i>=0;--i)
if(dep[t]-(1<<i)>=dep[LCA]){
ans=sts[t][i]*ans;
t=lca[t][i];
}
ans=tt*ans;
cout<<ans.a[0][0]+pos<<endl;
}
return 0;
}