萌新悬关求调
查看原帖
萌新悬关求调
746775
whznf楼主2025/7/13 12:18

RT,思路是第一篇题解的。目前只写了k=3的部分,但测大样例时50000个询问错3个。小样例通过了。 万分感谢。

#include <bits/stdc++.h>
#define int long long
using namespace std;

const int N=200010;
int n,q,k,v[N],num[N],a,b,s,t,dep[N],lca[N][19],num2[N];
vector<int>g[N];

inline int read(){
	int s=0;
	bool f=0;
	char c=getchar();
	while(!isdigit(c)){
		if(c=='-')f^=1;
		c=getchar();
	}
	while(isdigit(c)){
		s=(s<<1)+(s<<3)+(c^48);
		c=getchar();
	}
	return f==0?s:-s;
}

struct MATRIX{
	int n,m,a[3][3];
	MATRIX operator*(MATRIX b){
		MATRIX c;
		c.n=n,c.m=b.m;
		memset(c.a,63,sizeof(c.a));
		for(int i=0;i<c.n;++i)
			for(int j=0;j<c.m;++j)
				for(int k=0;k<m;++k)
					c.a[i][j]=min(c.a[i][j],a[i][k]+b.a[k][j]);
		return c;
	}
}base[N],stn[N][19],sts[N][19];

void dfs(int r,int f){
	lca[r][0]=f;
	dep[r]=dep[f]+1;
	for(int i:g[r])
		if(i!=f)dfs(i,r);
}

signed main(){
//	freopen("test.in","r",stdin);
//	freopen("test.out","w",stdout);
	n=read(),q=read(),k=read();
	memset(num,63,sizeof(num));
	const int inf=num[0];
	for(int i=1;i<=n;++i)v[i]=read();
	for(int i=1;i<n;++i){
		a=read(),b=read();
		g[a].push_back(b);
		g[b].push_back(a);
		num[a]=min(num[a],v[b]);
		num[b]=min(v[a],num[b]);
	}
	for(int i=1;i<=n;++i){
		base[i].n=base[i].m=3;
		base[i].a[0][0]=v[i];
		base[i].a[0][1]=v[i];
		base[i].a[0][2]=v[i];
		base[i].a[1][0]=0;
		base[i].a[1][1]=num[i];
		base[i].a[1][2]=inf;
		base[i].a[2][0]=inf;
		base[i].a[2][1]=0;
		base[i].a[2][2]=inf;
	}
	for(int i=1;i<=n;++i)stn[i][0]=sts[i][0]=base[i];
	dfs(1,0);
	for(int i=1;i<19;++i)
		for(int j=1;j<=n;++j){
			lca[j][i]=lca[lca[j][i-1]][i-1];
			sts[j][i]=sts[lca[j][i-1]][i-1]*sts[j][i-1];
			stn[j][i]=stn[j][i-1]*stn[lca[j][i-1]][i-1];
		}
	num2[0]=inf;
	for(int i=1;i<=n;++i){
		num2[i]=num[lca[i][0]];
		for(int j:g[i])
			num2[i]=min(num2[i],num[j]);
	}
	while(q--){
		s=read(),t=read();
		int pos=v[t];
		MATRIX ans,tt;
		tt.n=tt.m=3;
		tt.a[0][0]=tt.a[1][1]=tt.a[2][2]=0;
		tt.a[0][1]=tt.a[0][2]=tt.a[1][0]=tt.a[1][2]=tt.a[2][0]=tt.a[2][1]=inf;
		ans=tt;
		a=s,b=t;
		if(dep[a]<dep[b])swap(a,b);
		for(int i=18;i>=0;--i)
			if(dep[a]-(1<<i)>=dep[b])a=lca[a][i];
		int LCA;
		if(a==b)LCA=a;
		else{
			for(int i=18;i>=0;--i)
				if(lca[a][i]!=lca[b][i])a=lca[a][i],b=lca[b][i];
			LCA=lca[a][0];
		}
		for(int i=18;i>=0;--i)
			if(dep[s]-(1<<i)>=dep[LCA]){
				tt=tt*stn[s][i];
				s=lca[s][i];
			}
		tt=tt*base[LCA];
		t=lca[t][0];
		for(int i=18;i>=0;--i)
			if(dep[t]-(1<<i)>=dep[LCA]){
				ans=sts[t][i]*ans;
				t=lca[t][i];
			}
		ans=tt*ans;
		cout<<ans.a[0][0]+pos<<endl;
	}
	return 0;
}

2025/7/13 12:18
加载中...