// Note: the sample case failed locally, but the submission was accepted.
#include<bits/stdc++.h>
using namespace std;
// Node count; main() reads it first and uses it to bound both input loops.
int n;
// Per-node label read by main() for indices 1..n. find() only distinguishes
// "== 0" from "!= 0"; bfs() compares labels for inequality.
int c[100010];
// Running maximum maintained by bfs() and printed by main().
int ans;
// Adjacency lists; main() inserts every input pair in both directions.
std::vector<int> a[100010];
// Visited flags, cleared at the start of every find()/bfs() call.
bool used[100010];
// Results of find(): the largest step seen at a node with c == 0 (wmax) and
// that node's index (wpos) ...
int wmax;
int wpos;
// ... and the same pair for nodes with c != 0.
// NOTE(review): the w/b prefixes presumably stand for white/black - confirm
// against the problem statement.
int bmax;
int bpos;
// Queue entry for the breadth-first searches below: a node index paired with
// the number of edges walked from the search's start node to reach it.
struct point
{
    int pos;   // node index (indexes a, c and used)
    int step;  // edge count from the start node
};
void find(int start)
{
memset(used,0,sizeof(used));
queue<point> q;
point s;
s.pos=start,s.step=0;
q.push(s);
while(!q.empty())
{
point u=q.front();
q.pop();
for(auto i:a[u.pos])
{
point v=u;
v.pos=i,v.step+=1;
if(used[v.pos]==1) continue;
if(c[v.pos]==0)
{
if(v.step>wmax)
{
wmax=v.step;
wpos=v.pos;
}
}
else
{
if(v.step>bmax)
{
bmax=v.step;
bpos=v.pos;
}
}
used[v.pos]=1;
q.push(v);
}
}
return;
}
void bfs(int start)
{
memset(used,0,sizeof(used));
queue<point> q;
point s;
s.pos=start,s.step=0;
q.push(s);
while(!q.empty())
{
point u=q.front();
q.pop();
if(c[u.pos]!=c[start]) ans=max(ans,u.step);
for(auto i:a[u.pos])
{
point v=u;
v.pos=i,v.step+=1;
if(used[i]==1) continue;
used[i]=1;
q.push(v);
}
}
return;
}
/**
 * Reads the input, runs the searches and prints the result.
 *
 * Input (stdin): n, then n labels stored in c[1..n], then n-1 pairs x y that
 * are added to the adjacency lists in both directions.
 * Output (stdout): the final value of `ans`, without a trailing newline.
 *
 * @return 0 on success; 1 if the input is truncated, malformed, or contains
 *         an index that does not fit in the fixed-size global arrays.
 */
int main()
{
    // Largest usable index; c, a and used are all declared with this length.
    const int limit=static_cast<int>(sizeof(c)/sizeof(c[0]))-1;

    if(scanf("%d",&n)!=1||n>limit) return 1;
    for(int i=1;i<=n;i++)
    {
        if(scanf("%d",&c[i])!=1) return 1;
    }
    for(int i=1;i<n;i++)
    {
        int x,y;
        if(scanf("%d %d",&x,&y)!=2) return 1;
        // Reject endpoints that would index outside the adjacency array.
        if(x<0||x>limit||y<0||y>limit) return 1;
        a[x].push_back(y);
        a[y].push_back(x);
    }

    // Locate the deepest node of each label class as seen from node 1, then
    // search outward from each of those two nodes to update ans.
    find(1);
    bfs(wpos);
    bfs(bpos);
    printf("%d",ans);
    return 0;
}