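// Approach: min-cost max-flow. I suspect a detail of the flow-network model is wrong; help from anyone experienced would be appreciated.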
#include <bits/stdc++.h>
using namespace std;
// Capacity limits: N vertices of the flow network, M arcs (forward + residual).
const int N = 100010;
const int M = 1000010;
// rd[k] is the tile bit for an opening toward direction k, and (dx[k], dy[k])
// is the row/column step to the neighbouring cell in that same direction.
const int rd[4] = {1, 2, 4, 8};
const int dx[4] = {-1, 0, 1, 0};
const int dy[4] = {0, 1, 0, -1};
// Grid size, source vertex, sink vertex, index of the last arc written.
int n, m, s, t, tot;
// Adjacency lists: head[u] is u's first arc; to[e] / nxt[e] are the target of
// arc e and the next arc leaving the same vertex.
int head[N];
int to[M];
int nxt[M];
// Tile masks of the grid, 1-indexed.
int mp[2021][2021];
// Residual capacity and per-unit cost of each arc.
long long c[M];
long long cst[M];
// Total flow pushed and total cost paid by the augmentations.
long long ans, ans2;
// Appends one arc u -> v with residual capacity w and per-unit cost x to the
// front of u's adjacency list. The arc's index is the new value of tot.
void add(int u, int v, int w, int x) {
    const int e = ++tot;
    to[e] = v;
    nxt[e] = head[u];
    head[u] = e;
    c[e] = w;
    cst[e] = x;
}
// Adds a flow arc u -> v (capacity w, cost x) together with its residual twin
// v -> u (capacity 0, cost -x). The twins are created back to back, so
// update() can reach one from the other via index ^ 1.
// Zero-capacity requests are ignored entirely.
void con(int u, int v, int w, int x) {
    if (w == 0) return;
    add(u, v, w, x);
    add(v, u, 0, -x);
}
// Vertex id of layer k (0..4) for grid cell (x, y). Each layer holds n * m
// vertices, and within a layer cells are numbered row-major starting at 1.
int node(int x, int y, int k) {
    const int perLayer = n * m;
    const int inLayer = (x - 1) * m + y;
    return perLayer * k + inLayer;
}
// True when (x, y) lies inside the 1-indexed n x m grid.
bool ok(int x, int y) {
    const bool rowInside = 1 <= x && x <= n;
    const bool colInside = 1 <= y && y <= m;
    return rowInside && colInside;
}
// Number of set bits in x; applied to a tile mask it counts the tile's openings.
int cnt(int x) {
    int bits = 0;
    for (; x; x >>= 1)
        bits += x & 1;
    return bits;
}
// Per-vertex scratch state shared by spfa() and update().
int vis[N];         // 1 while the vertex is waiting in the SPFA queue
int pre[N];         // arc used to reach the vertex on the current shortest path
int q[N];           // SPFA queue storage
long long incf[N];  // bottleneck residual capacity from s to the vertex
long long dis[N];   // cheapest known cost from s to the vertex
bool spfa() {
memset(vis, 0, sizeof vis);
memset(dis, 0x3f, sizeof dis);
int hd = 1, tl = 0;
vis[s] = 1;
q[++tl] = s;
incf[s] = 2e17;
dis[s] = 0;
while (hd <= tl) {
int u = q[hd]; hd++;
vis[u] = 0;
for (int i = head[u]; i; i = nxt[i])
if (c[i] && dis[to[i]] > dis[u] + cst[i]) {
int v = to[i];
incf[v] = min(incf[u], c[i]);
dis[v] = dis[u] + cst[i];
pre[v] = i;
if (vis[v]) continue;
q[++tl] = v;
vis[v] = 1;
}
}
if (dis[t] > 3e12) return 0;
return 1;
}
// Pushes incf[t] units along the path recorded in pre[] by the last spfa()
// call, from t back to s. Updates the residual capacities of each arc and its
// twin (index ^ 1), then adds the pushed flow and its cost to the totals.
void update() {
    const long long pushed = incf[t];
    for (int v = t; v != s; ) {
        const int e = pre[v];
        c[e] -= pushed;
        c[e ^ 1] += pushed;
        v = to[e ^ 1];  // the twin arc points back to the predecessor vertex
    }
    ans += pushed;
    ans2 += pushed * dis[t];
}
// Debug dump: for every forward arc (even index) prints
// "from to remaining-capacity flow*cost []", where the flow on a forward arc
// is the capacity of its residual twin.
void watch() {
    for (int e = 2; e <= tot; e += 2) {
        const int back = e ^ 1;
        if (c[back] < 0) continue;
        printf("%d %d %lld %lld []\n", to[back], to[e], c[e], c[back] * cst[e]);
    }
}
// Reads the next decimal integer from stdin. Non-digit characters before it
// are skipped, and any '-' among them makes the result negative.
inline int read() {
    int sign = 1, value = 0;
    char ch = getchar();
    for (; ch < '0' || ch > '9'; ch = getchar())
        if (ch == '-') sign = -1;
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        value = value * 10 + ch - '0';
    return value * sign;
}
// Adds one unit-capacity arc between two vertices of the same cell, oriented
// by the cell's colour. Odd cells (fed from s) get a -> b; even cells
// (draining into t) get the mirrored arc b -> a.
static void orientedArc(bool odd, int a, int b, int cost) {
    if (odd) con(a, b, 1, cost);
    else con(b, a, 1, cost);
}

// Reads an n x m grid of tile masks, builds the flow network and runs min-cost
// max-flow. Prints the cost of a flow that routes every pipe end, or -1 when
// that is impossible. Re-pointing an end costs 1 or 2; presumably this models
// 90/180-degree tile rotations — confirm against the problem statement.
int main() {
    n = read(); m = read();
    // node(i, j, 0..3) are a cell's four side ports and node(i, j, 4) is its
    // centre; s and t are numbered after all 5 * n * m cell vertices.
    s = n * m * 5 + 1, t = n * m * 5 + 2;
    tot = 1;  // first arc gets index 2, so an arc and its twin are i and i ^ 1
    long long sum = 0, sum2 = 0;  // pipe ends on even cells / on odd cells
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= m; j++) {
            scanf("%d", &mp[i][j]);
            const int mask = mp[i][j];
            const int tmp = cnt(mask);
            const bool odd = ((i + j) & 1) != 0;
            const int centre = node(i, j, 4);
            // One unit of flow per pipe end: odd cells are supplied from s,
            // even cells drain into t.
            if (odd) con(s, centre, tmp, 0), sum2 += tmp;
            else con(centre, t, tmp, 0), sum += tmp;
            for (int k = 0; k < 4; k++) {
                // Centre <-> port k for every opening the tile currently has.
                if (mask & rd[k]) orientedArc(odd, centre, node(i, j, k), 0);
                // Facing ports of neighbouring cells. Every adjacent pair has
                // exactly one odd cell, so adding the arc only from the odd side
                // creates it exactly once. Adding it from both sides would
                // duplicate it, giving the port pair capacity 2 and letting two
                // pipe ends share one side of a tile.
                if (odd && ok(i + dx[k], j + dy[k]))
                    con(node(i, j, k), node(i + dx[k], j + dy[k], k ^ 2), 1, 0);
            }
            if (tmp == 1) {
                // Single end: it may move to an adjacent side (cost 1) or to the
                // opposite side (cost 2).
                for (int k = 0; k < 4; k++)
                    if (mask & rd[k]) {
                        orientedArc(odd, node(i, j, k), node(i, j, k ^ 1), 1);
                        orientedArc(odd, node(i, j, k), node(i, j, k ^ 3), 1);
                        orientedArc(odd, node(i, j, k), node(i, j, k ^ 2), 2);
                    }
            } else if (tmp == 2) {
                // Corner piece: each end may flip to its opposite side for cost 1.
                // Straight pieces (both ends opposite each other) get no arcs.
                for (int k = 0; k < 4; k++)
                    if ((mask & rd[k]) && !(mask & rd[k ^ 2]))
                        orientedArc(odd, node(i, j, k), node(i, j, k ^ 2), 1);
            } else if (tmp == 3) {
                // T-piece: the missing side k can be filled from an adjacent
                // side (cost 1) or from the opposite side (cost 2).
                for (int k = 0; k < 4; k++)
                    if (!(mask & rd[k])) {
                        orientedArc(odd, node(i, j, k ^ 1), node(i, j, k), 1);
                        orientedArc(odd, node(i, j, k ^ 3), node(i, j, k), 1);
                        orientedArc(odd, node(i, j, k ^ 2), node(i, j, k), 2);
                    }
            }
        }
    while (spfa()) update();
    // Feasible only when both colour classes have equally many ends and the
    // max flow matches all of them.
    if (sum != sum2 || ans < sum) printf("-1\n");
    else printf("%lld\n", ans2);
    return 0;
}