#include <bits/stdc++.h>
using namespace std;
int n,a[100005],m,f[100005];
int l=0;
struct node{
int a=0,b=0,q=0;
}p[100005];
void dfs(int c,int s){
int z=0;
for(int i=f[c];i!=0;i=p[i].q){
if(a[p[i].a]!=a[p[i].b]){
z=1;
dfs(p[i].b,s+1);
}
}
if(z==0){
l+=s;
}
}
int main(){
scanf("%d",&n);
m=n-1;
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
}
for(int i=1;i<=m;i++){
scanf("%d%d",&p[i].a,&p[i].b);
p[i].q=f[p[i].a];
f[p[i].a]=i;
}
int ans=0;
for(int i=1;i<=n;i++){
l=0;
dfs(i,1);
ans=max(ans,l);
}
cout<<ans;
return 0;
}