我是这么做的，请各位大佬帮忙看看这样可以吗？
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const ll N=1e5+10;            // capacity of the node arrays; nodes are indexed 1..n, so n must be < N
// a[i]    : value of node i (read in solve)
// dep[i]  : sum of a[] along the path from the current traversal root to i
//           (also written for the first node on a branch whose sum exceeds k)
// dep2[i] : number of nodes on that path; written only for nodes whose path sum is <= k
//           (dfs never clears these arrays: nodes a traversal does not reach keep old values)
// k       : upper limit for the path sum
// n       : number of nodes
ll n,a[N],dep[N],dep2[N],k;
vector<ll> v[N];              // adjacency lists of the undirected tree

// Walks the tree from `now`, whose parent is `pre` (pass pre == now to start a fresh
// traversal rooted at `now`), and for every node x it reaches sets
//   dep[x]  = dep[parent] + a[x]
//   dep2[x] = dep2[parent] + 1      (only when dep[x] <= k)
// A node with dep[x] > k is not expanded, so its whole subtree is skipped.
//
// Uses an explicit stack instead of recursion: on a path-shaped tree the recursive
// version nests about n calls deep (~1e5 here), which can overflow small default stacks.
void dfs(ll now,ll pre){
    vector<pair<ll,ll> > stk;                 // pending (node, parent) pairs
    stk.push_back(make_pair(now,pre));
    while(!stk.empty()){
        const ll u=stk.back().first;
        const ll p=stk.back().second;
        stk.pop_back();
        // The traversal root (u == p) starts from 0 rather than reading its own,
        // possibly stale, dep[]/dep2[] entries.
        const bool root=(u==p);
        dep[u]=(root ? 0 : dep[p])+a[u];
        if(dep[u]>k) continue;                // over the limit: do not extend this branch
        dep2[u]=(root ? 0 : dep2[p])+1;
        for(ll x:v[u]){
            if(x==p) continue;                // do not walk back to the parent
            stk.push_back(make_pair(x,u));
        }
    }
}
// Handles one test case.
//   input : n k, then a[1..n], then n-1 edges "x y" of an undirected tree
//   output: the largest dep2[i] after two traversals:
//     1) dfs from node 1, then id = lowest-indexed node with the largest dep2;
//     2) dfs from id. Any dep2[i] the second pass does not overwrite keeps its
//        first-pass value and still takes part in the final max.
// NOTE(review): this is the two-pass tree-diameter trick applied under the sum
// limit k. If the task is "most nodes on a path whose a-sum is <= k", it is not
// guaranteed correct:
//   - when a[1] > k the first pass stops at the root and 0 is printed, even if
//     some other node alone satisfies a[i] <= k;
//   - the node with the largest dep2 from node 1 need not be an endpoint of the
//     best path (e.g. a zero-weight branch ending in a heavy leaf may be chosen
//     over two light branches that together form a longer valid path).
// That task is usually solved with e.g. centroid decomposition; confirm the
// problem statement before relying on this.
void solve(){
    cin >> n >> k;
    if(n<=0){                       // no nodes: nothing to traverse
        cout << 0 << '\n';
        return;
    }
    for(int i=1;i<=n;i++){
        cin >> a[i];
        v[i].clear();
        // Drop values left over from a previous test case: dfs(1,1) reads
        // dep[1]/dep2[1] as its starting point, and the scans below read every dep2[i].
        dep[i]=0;
        dep2[i]=0;
    }
    for(int i=1;i<n;i++){
        ll x,y;
        cin >> x >> y;
        v[x].push_back(y);
        v[y].push_back(x);
    }
    dfs(1,1);
    // max_element returns the first maximum, i.e. the lowest-indexed node.
    ll id=max_element(dep2+1,dep2+n+1)-dep2;
    dep2[id]=0;
    dep[id]=0;
    dfs(id,id);
    cout << *max_element(dep2+1,dep2+n+1) << '\n';
}
ll T=1;  // number of test cases to run; fixed at 1 (no count is read from input)
int main(){
    // The input holds about 3n integers (n values plus 2(n-1) edge endpoints), so
    // decouple iostreams from C stdio and untie cin from cout for faster reading.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    while(T--) solve();
    return 0;
}