不知道为什么 WA 了第 3 个和第 8 个测试点。
#include <bits/stdc++.h>
using namespace std;
// Alias for long long (not referenced by the code visible in this file).
using ll = long long;
// Upper bound used to size the per-row / per-column arrays below.
const int N = 1010;
// Large sentinel value (not referenced by the code visible in this file).
const int INF = 1e9 + 10;
// n, m   : grid dimensions read from input (rows 1..n, columns 1..m).
// a, b   : the two n-by-m grids read from input; the code treats 0 as an
//          empty cell and only processes entries > 0 as values.
// pl[v]  : 0-based rank of value v among the non-zero entries of a row of a
//          (filled by P for the row being processed).
// pos[v] : row index of a in which value v occurs, recorded only for rows
//          with vad[row] == 0.
// vad[i] : set to 1 in main for rows whose non-zero values in a and b form
//          the same set; such rows are fully handled in main.
// f      : set to 1 when at least one cell of a is 0.
// st[v]  : per-value scratch counter/marker; main and Solve both leave it
//          zeroed again after use.
// ans    : running answer printed at the end of main.
// c      : binary-indexed (Fenwick) array over 1..m holding prefix maxima,
//          used by UP / Query.
int n, m, a[N][N], b[N][N], pl[N * N], pos[N * N], vad[N], f, st[N * N], ans, c[N];
// com : rows collected by the most recent dfs run.
// gh  : undirected adjacency lists between row indices, built in main.
vector<int> com, gh[N];
// vis[i] : row i has already been visited by dfs.
bool vis[N];
// Returns x masked with its own negation, i.e. the value of the lowest set
// bit of x (0 when x is 0).
int lowbit(int x) {
    const int negated = -x;
    return negated & x;
}
// Binary-indexed-tree update: starting at index x and repeatedly adding
// lowbit(x) while x <= m, raises c[x] to at least p.
void UP(int x, int p) {
    while (x <= m) {
        if (c[x] < p) {
            c[x] = p;
        }
        x += lowbit(x);
    }
}
// Binary-indexed-tree query: walks from index x down to 0 by stripping
// lowbit(x) and returns the largest c[] value met on the way (0 if none).
int Query(int x) {
    int best = 0;
    while (x != 0) {
        if (best < c[x]) {
            best = c[x];
        }
        x -= lowbit(x);
    }
    return best;
}
// Depth-first traversal over gh starting at x: marks every reached node in
// vis and appends it to com in visiting order.
void dfs(int x) {
    vis[x] = 1;
    com.push_back(x);
    for (auto it = gh[x].begin(); it != gh[x].end(); ++it) {
        const int next = *it;
        if (vis[next]) {
            continue;
        }
        dfs(next);
    }
}
// For row id: ranks the non-zero entries of a[id] left to right (stored in
// pl), then scans the non-zero entries of b[id] left to right and, using the
// prefix-maximum tree c, computes the longest run of b-entries whose ranks
// are strictly increasing. Returns that length.
// Every non-zero entry of b[id] is looked up in pl, so callers are expected
// to pass a row in which those values also occur in a[id].
int P(int id) {
    // Assign 0-based ranks to the values present in a[id].
    int ranked = 0;
    for (int col = 1; col <= m; ++col) {
        const int value = a[id][col];
        if (value > 0) {
            pl[value] = ranked;
            ++ranked;
        }
    }

    // Reset the tree before processing this row.
    fill(c, c + m + 1, 0);

    for (int col = 1; col <= m; ++col) {
        const int value = b[id][col];
        if (value <= 0) {
            continue;
        }
        const int rank = pl[value];
        // Best chain ending on a strictly smaller rank, extended by this one.
        const int chain = Query(rank) + 1;
        UP(rank + 1, chain);
    }

    // Take the best value over all prefixes 1..ranked.
    int best = 0;
    for (int upto = ranked; upto >= 1; --upto) {
        best = max(best, Query(upto));
    }
    return best;
}
// Processes one group of rows (indices in v, as collected by dfs).
// For each row, values that occur only once among that row's non-zero
// entries of a and b (st < 2) are temporarily zeroed, P(row) is subtracted
// from ans, and the row is restored. Afterwards the number of non-zero
// a-entries of the whole group is added to ans. If no row of the group has
// a zero cell in a, one more is added to ans -- unless the global flag f is
// 0, in which case "-1" is printed and the program exits.
void Solve(vector<int>& v) {
    // Does any row of this group contain a zero cell in a?
    bool hasHole = false;
    for (int row : v) {
        for (int col = 1; col <= m; ++col) {
            hasHole = hasHole || a[row][col] == 0;
        }
    }

    // Adds delta to st[] for every positive entry of the given row of a and b.
    auto bump = [](int row, int delta) {
        for (int col = 1; col <= m; ++col) {
            if (a[row][col] > 0) {
                st[a[row][col]] += delta;
            }
            if (b[row][col] > 0) {
                st[b[row][col]] += delta;
            }
        }
    };

    int total = 0;
    for (int row : v) {
        bump(row, 1);

        // Remember the row, count its a-entries, and mask out every value
        // that is not present in both a[row] and b[row].
        int keptA[N], keptB[N];
        for (int col = 1; col <= m; ++col) {
            const int av = keptA[col] = a[row][col];
            const int bv = keptB[col] = b[row][col];
            if (av > 0) {
                ++total;
                if (st[av] < 2) {
                    a[row][col] = 0;
                }
            }
            if (bv > 0 && st[bv] < 2) {
                b[row][col] = 0;
            }
        }

        ans -= P(row);

        // Put the original row back before undoing the counters.
        for (int col = 1; col <= m; ++col) {
            a[row][col] = keptA[col];
            b[row][col] = keptB[col];
        }
        bump(row, -1);
    }

    ans += total;
    if (hasHole) {
        return;
    }
    if (!f) {
        cout << "-1";
        exit(0);
    }
    ++ans;
}
// Entry point. Reads n, m and the two n-by-m grids a and b from stdin, then
// prints the accumulated answer ans, or "-1" when a required rearrangement
// needs an empty cell and a contains none.
//
// Phase 1: rows whose non-zero values in a and b form the same set are
//          settled directly (vad[i] = 1) using P(i).
// Phase 2: the remaining rows are linked whenever a value has to move from
//          one row to another; each connected group is passed to Solve.
int main() {
    ios::sync_with_stdio(0), cin.tie(0);
    cin >> n >> m;
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            cin >> a[i][j];
            // Global flag f: at least one cell of a is empty.
            if (a[i][j] == 0) {
                f = 1;
            }
        }
    }
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            cin >> b[i][j];
        }
    }

    for (int i = 1; i <= n; i++) {
        // `same` stays 1 only if a[i] and b[i] hold exactly the same values.
        // It must NOT be called `f`: the original local `f` shadowed the
        // global flag and made the "-1" check below unreachable.
        int same = 1, cnt = 0;
        for (int j = 1; j <= m; j++) {
            if (a[i][j] > 0) {
                st[a[i][j]] = 1;
                cnt++;
            }
        }
        for (int j = 1; j <= m; j++) {
            if (b[i][j] > 0 && !st[b[i][j]]) {
                // b[i] wants a value that a[i] does not have.
                same = 0;
            } else if (b[i][j] > 0) {
                st[b[i][j]] = 0;
            }
        }
        for (int j = 1; j <= m; j++) {
            // Any mark still set is a value of a[i] missing from b[i];
            // clearing it also leaves st zeroed for the next row.
            if (st[a[i][j]]) {
                same = 0;
                st[a[i][j]] = 0;
            }
        }
        if (same) {
            vad[i] = 1;
            // Entries of the row that are not part of the chain found by P.
            int tmp = cnt - P(i);
            ans += tmp;
            if (cnt == m && tmp != 0) {
                // The row is completely full and still needs changes: this
                // is only possible if some cell of a is empty (global f).
                if (!f) {
                    cout << "-1";
                    return 0;
                } else {
                    ans++;
                }
            }
        }
    }

    // Record, for every value in an unsettled row, the row it starts in.
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            if (a[i][j] > 0 && !vad[i]) {
                pos[a[i][j]] = i;
            }
        }
    }
    // Connect each unsettled row with the rows its wanted values come from.
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            if (b[i][j] > 0 && !vad[i] && pos[b[i][j]] != i) {
                gh[pos[b[i][j]]].push_back(i);
                gh[i].push_back(pos[b[i][j]]);
            }
        }
    }
    // Handle every connected group of unsettled rows once.
    for (int i = 1; i <= n; i++) {
        if (!vad[i] && !vis[i]) {
            com.clear();
            dfs(i);
            Solve(com);
        }
    }
    cout << ans;
    return 0;
}