第一次代码
#include <bits/stdc++.h>
using namespace std;
int n;
vector<int> p[100005];
bool a[100005];
struct xxs {
int id, dp0, sw0, dp1, sw1, len;
} jd[100005];
bool b[100005];
void dfs(int x, int dp) {
b[x] = 1;
if (p[x].size() == 1) {
jd[x].id = x, jd[x].len = -1;
if (a[x])
jd[x].dp1 = dp, jd[x].sw1 = dp, jd[x].sw0 = -1, jd[x].dp0 = -1;
else
jd[x].dp0 = dp, jd[x].sw0 = dp, jd[x].sw1 = -1, jd[x].dp1 = -1;
//cout << jd[x].dp0 << " " << x << endl;
return ;
}
jd[x].sw0 = 100005, jd[x].dp0 = -1, jd[x].sw1 = 100005, jd[x].dp1 = -1;
if (a[x])
jd[x].sw1 = dp, jd[x].dp1 = dp;
else
jd[x].sw0 = dp, jd[x].dp0 = dp;
jd[x].len = -1, jd[x].id = x;
int b1 = -1, b2 = -1, s1 = -1, s2 = -1, ib1 = -1, ib2 = -1, is1 = -1, is2 = -1;
for (int i = 0; i < p[x].size(); i++) {
int son = p[x][i];
if (b[son])
continue ;
//cout << son << '\n';
dfs(son, dp + 1);
if (jd[son].dp1 > b1) {
s1 = b1, is1 = ib1;
b1 = jd[son].dp1, ib1 = son;
} else if (jd[son].dp1 > s1) {
s1 = jd[son].dp1, is1 = son;
}
if (jd[son].dp0 > b2) {
s2 = b2, is2 = ib2;
b2 = jd[son].dp0, ib2 = son;
} else if (jd[x].dp0 > s2) {
//cout << son << '\n';
//cout << jd[son].dp0;
s2 = jd[son].dp0, is2 = son;
}
jd[x].dp1 = max(jd[son].dp1, jd[x].dp1);
jd[x].dp0 = max(jd[son].dp0, jd[x].dp0);
jd[x].sw1 = min(jd[son].sw1, jd[x].sw1);
jd[x].sw0 = min(jd[son].sw0, jd[x].sw0);
jd[x].len = max(jd[son].len, jd[x].len);
}
//cout << jd[x].len << b1 << b2 << ib1 << ib2;
//cout << is1 << ' ' << is2 << '\n';
if (ib1 != ib2) {
jd[x].len = max(jd[x].len, b1 + b2 - 2 * dp);
} else if (is1 != -1 ) {
//cout << 'a';
jd[x].len = max(jd[x].len, b2 + s1 - 2 * dp);
} else if (is2 != -1) {
jd[x].len = max(jd[x].len, b1 + s2 - 2 * dp);
}
//cout << b2 << " " << s1 << dp << " " << jd[x].len << " " << x << endl;
return ;
}
int main() {
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
for (int i = 1; i <= n - 1; i++) {
int u, v;
cin >> u >> v;
p[u].push_back(v);
p[v].push_back(u);
}
dfs(1, 1);
cout << jd[1].len;
return 0;
}
第二次代码
#include <bits/stdc++.h>
using namespace std;
int n;
vector<int> p[100005];
bool a[100005];
struct xxs {
int id, dp0, sw0, dp1, sw1, len;
} jd[100005];
bool b[100005];
void dfs(int x, int dp) {
b[x] = 1;/*
if (p[x].size() == 1) {
jd[x].id = x, jd[x].len = -1;
if (a[x])
jd[x].dp1 = dp, jd[x].sw1 = dp, jd[x].sw0 = -1, jd[x].dp0 = -1;
else
jd[x].dp0 = dp, jd[x].sw0 = dp, jd[x].sw1 = -1, jd[x].dp1 = -1;
return ;
}*/
jd[x].sw0 = 100005, jd[x].dp0 = -1, jd[x].sw1 = 100005, jd[x].dp1 = -1;
if (a[x])
jd[x].sw1 = dp, jd[x].dp1 = dp;
else
jd[x].sw0 = dp, jd[x].dp0 = dp;
jd[x].len = -1, jd[x].id = x;
int b1 = jd[x].dp1, b2 = jd[x].dp0, s1 = -1, s2 = -1, ib1 = x, ib2 = x, is1 = -1, is2 = -1;
for (int i = 0; i < p[x].size(); i++) {
int son = p[x][i];
if (b[son])
continue ;
//cout << son << '\n';
dfs(son, dp + 1);
if (jd[son].dp1 > b1) {
s1 = b1, is1 = ib1;
b1 = jd[son].dp1, ib1 = son;
} else if (jd[son].dp1 > s1) {
s1 = jd[son].dp1, is1 = son;
}
if (jd[son].dp0 > b2) {
s2 = b2, is2 = ib2;
b2 = jd[son].dp0, ib2 = son;
} else if (jd[x].dp0 > s2) {
s2 = jd[son].dp0, is2 = son;
}
jd[x].dp1 = max(jd[son].dp1, jd[x].dp1);
jd[x].dp0 = max(jd[son].dp0, jd[x].dp0);
jd[x].sw1 = min(jd[son].sw1, jd[x].sw1);
jd[x].sw0 = min(jd[son].sw0, jd[x].sw0);
jd[x].len = max(jd[son].len, jd[x].len);
}
//cout << b1 << b2 << s1 << s2 << '\n';
if (ib1 != ib2) {
jd[x].len = max(jd[x].len, b1 + b2 - 2 * dp);
} else if (is1 != -1 ) {
jd[x].len = max(jd[x].len, b2 + s1 - 2 * dp);
} else if (is2 != -1) {
jd[x].len = max(jd[x].len, b1 + s2 - 2 * dp);
}
return ;
}
int main() {
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
for (int i = 1; i <= n - 1; i++) {
int u, v;
cin >> u >> v;
p[u].push_back(v);
p[v].push_back(u);
}
dfs(1, 1);
cout << jd[1].len;
return 0;
}
/*
3
0 1 0
1 2
2 3
*/
求hack数据。
第一次对的数据点在第二次错了
[第二次](#include <bits/stdc++.h> using namespace std;
int n; vector p[100005]; bool a[100005];
struct xxs { int id, dp0, sw0, dp1, sw1, len; } jd[100005]; bool b[100005];
void dfs(int x, int dp) { b[x] = 1;/* if (p[x].size() == 1) { jd[x].id = x, jd[x].len = -1; if (a[x]) jd[x].dp1 = dp, jd[x].sw1 = dp, jd[x].sw0 = -1, jd[x].dp0 = -1; else jd[x].dp0 = dp, jd[x].sw0 = dp, jd[x].sw1 = -1, jd[x].dp1 = -1; return ;
}*/
jd[x].sw0 = 100005, jd[x].dp0 = -1, jd[x].sw1 = 100005, jd[x].dp1 = -1;
if (a[x])
jd[x].sw1 = dp, jd[x].dp1 = dp;
else
jd[x].sw0 = dp, jd[x].dp0 = dp;
jd[x].len = -1, jd[x].id = x;
int b1 = jd[x].dp1, b2 = jd[x].dp0, s1 = -1, s2 = -1, ib1 = x, ib2 = x, is1 = -1, is2 = -1;
for (int i = 0; i < p[x].size(); i++) {
int son = p[x][i];
if (b[son])
continue ;
//cout << son << '\n';
dfs(son, dp + 1);
if (jd[son].dp1 > b1) {
s1 = b1, is1 = ib1;
b1 = jd[son].dp1, ib1 = son;
} else if (jd[son].dp1 > s1) {
s1 = jd[son].dp1, is1 = son;
}
if (jd[son].dp0 > b2) {
s2 = b2, is2 = ib2;
b2 = jd[son].dp0, ib2 = son;
} else if (jd[x].dp0 > s2) {
s2 = jd[son].dp0, is2 = son;
}
jd[x].dp1 = max(jd[son].dp1, jd[x].dp1);
jd[x].dp0 = max(jd[son].dp0, jd[x].dp0);
jd[x].sw1 = min(jd[son].sw1, jd[x].sw1);
jd[x].sw0 = min(jd[son].sw0, jd[x].sw0);
jd[x].len = max(jd[son].len, jd[x].len);
}
//cout << b1 << b2 << s1 << s2 << '\n';
if (ib1 != ib2) {
jd[x].len = max(jd[x].len, b1 + b2 - 2 * dp);
} else if (is1 != -1 ) {
jd[x].len = max(jd[x].len, b2 + s1 - 2 * dp);
} else if (is2 != -1) {
jd[x].len = max(jd[x].len, b1 + s2 - 2 * dp);
}
return ;
}
int main() { cin >> n;
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
for (int i = 1; i <= n - 1; i++) {
int u, v;
cin >> u >> v;
p[u].push_back(v);
p[v].push_back(u);
}
dfs(1, 1);
cout << jd[1].len;
return 0;
} /* 3 0 1 0 1 2 2 3 */)