WA #3、#8 求助
查看原帖
wa#3#8 求助
199821
LongDouble楼主2022/2/13 11:15

思路是费用流（最小费用最大流），我猜是建模细节出了问题，求大佬帮忙看看 /kel

#include <bits/stdc++.h>
using namespace std;
// Capacity bounds: N for per-node arrays, M for per-edge arrays.
const int N = 100010;
const int M = 1000010;
// rd[k] is the bit of a cell's mask that marks an opening on side k.
const int rd[4] = {1, 2, 4, 8};
// (dx[k], dy[k]) is the grid offset of side k:
// k = 0 -> row - 1, 1 -> col + 1, 2 -> row + 1, 3 -> col - 1.
// Hence k ^ 2 is always the opposite side of k.
const int dx[4] = {-1, 0, 1, 0};
const int dy[4] = {0, 1, 0, -1};
int n;              // grid rows
int m;              // grid columns
int s;              // super source node id
int t;              // super sink node id
int tot;            // index of the most recently written edge
int head[N];        // first outgoing edge of each node (0 = none)
int to[M];          // edge target node
int nxt[M];         // next edge in the same adjacency list
int mp[2021][2021]; // opening bitmask of each cell, 1-based
long long c[M];     // residual capacity of each edge
long long cst[M];   // per-unit cost of each edge
long long ans;      // total flow pushed
long long ans2;     // total cost paid

// Appends a directed edge u -> v with capacity w and per-unit cost x,
// pushing it onto the front of u's linked adjacency list (head/nxt).
void add(int u, int v, int w, int x) {
	const int e = ++tot;
	to[e] = v;
	nxt[e] = head[u];
	c[e] = w;
	cst[e] = x;
	head[u] = e;
}

// Adds the flow edge u -> v (capacity w, cost x) together with its residual
// partner v -> u (capacity 0, cost -x). Zero-capacity requests add nothing.
void con(int u, int v, int w, int x) {
	if (w != 0) {
		add(u, v, w, x);
		add(v, u, 0, -x);
	}
}

// Node id of layer k (0..3 = sides, 4 = centre) for cell (x, y), 1-based.
// Each layer holds n * m consecutive ids starting at k * n * m + 1.
int node(int x, int y, int k) {
	const int cell = (x - 1) * m + y;
	return cell + k * (n * m);
}

// True when (x, y) lies inside the 1-based n x m grid.
bool ok(int x, int y) {
	if (x < 1 || x > n) return false;
	return y >= 1 && y <= m;
}

// Number of set bits in x (number of openings of a cell mask).
int cnt(int x) {
	int bits = 0;
	for (; x != 0; x >>= 1)
		bits += x & 1;
	return bits;
}

// Working arrays of spfa()/update(), indexed by node id.
int vis[N];        // 1 while the node is waiting in the queue
int pre[N];        // edge through which the node was last relaxed
int q[N];          // queue storage
long long incf[N]; // bottleneck capacity of the path reaching the node
long long dis[N];  // best known cost distance from s

bool spfa() {
	memset(vis, 0, sizeof vis);
	memset(dis, 0x3f, sizeof dis);
	int hd = 1, tl = 0;
	vis[s] = 1;
	q[++tl] = s;
	incf[s] = 2e17;
	dis[s] = 0;
	while (hd <= tl) {
		int u = q[hd]; hd++;
		vis[u] = 0;
		for (int i = head[u]; i; i = nxt[i]) 
			if (c[i] && dis[to[i]] > dis[u] + cst[i]) {
				int v = to[i];
				incf[v] = min(incf[u], c[i]);
				dis[v] = dis[u] + cst[i];
				pre[v] = i;
				if (vis[v]) continue;
				q[++tl] = v;
				vis[v] = 1;
			}
	}
	if (dis[t] > 3e12) return 0;
	return 1;
}

// Pushes incf[t] units along the path recorded in pre[], walking back from t
// to s. It moves capacity from each used edge to its partner (edge ^ 1), then
// adds the flow to ans and its cost to ans2.
void update() {
	const long long f = incf[t];
	for (int u = t; u != s; ) {
		const int e = pre[u];
		c[e] -= f;
		c[e ^ 1] += f;
		u = to[e ^ 1];
	}
	ans += f;
	ans2 += f * dis[t];
}

// Debug dump: for each even edge index e (from 2 upward) prints
// "<to[e ^ 1]> <to[e]> <c[e]> <c[e ^ 1] * cst[e]> []", skipping edges whose
// partner's capacity is negative.
void watch() {
	for (int e = 2; e <= tot; e += 2) {
		const int r = e ^ 1;
		if (c[r] < 0) continue;
		printf("%d %d %lld %lld []\n", to[r], to[e], c[e], c[r] * cst[e]);
	}
}

// Reads one signed decimal integer from stdin. Any non-digit characters
// before it are skipped; a '-' among them makes the result negative.
inline int read() {
	int sign = 1;
	char ch = getchar();
	for (; ch < '0' || ch > '9'; ch = getchar())
		if (ch == '-') sign = -1;
	int value = 0;
	for (; ch >= '0' && ch <= '9'; ch = getchar())
		value = value * 10 + ch - '0';
	return value * sign;
}

int main() {
	n = read(); m = read();
	s = n * m * 5 + 1, t = n * m * 5 + 2;
	tot = 1;
	long long sum = 0, sum2 = 0;
	for (int i = 1; i <= n; i++) 
		for (int j = 1; j <= m; j++) {
			scanf("%d", &mp[i][j]);
			int tmp = cnt(mp[i][j]);
			if ((i + j) & 1) con(s, node(i, j, 4), tmp, 0), sum2 += tmp;
			else con(node(i, j, 4), t, tmp, 0), sum += tmp;
			for (int k = 0; k < 4; k++) {
				if (mp[i][j] & rd[k]) {
					if ((i + j) & 1) con(node(i, j, 4), node(i, j, k), 1, 0);
					else con(node(i, j, k), node(i, j, 4), 1, 0);
				}
				if (ok(i + dx[k], j + dy[k])) {
					//printf("(%d %d)\n", node(i, j, k), node(i + dx[k], j + dy[k], k ^ 2));
					if ((i + j) & 1) con(node(i, j, k), node(i + dx[k], j + dy[k], k ^ 2), 1, 0);
					else con(node(i + dx[k], j + dy[k], k ^ 2), node(i, j, k), 1, 0);
				}
			}
			if (tmp == 1) {
				for (int k = 0; k < 4; k++)
					if (mp[i][j] & rd[k]) {
						if ((i + j) & 1) {
							con(node(i, j, k), node(i, j, k ^ 1), 1, 1);
							con(node(i, j, k), node(i, j, k ^ 3), 1, 1);
							con(node(i, j, k), node(i, j, k ^ 2), 1, 2);
						}
						else {
							con(node(i, j, k ^ 1), node(i, j, k), 1, 1);
							con(node(i, j, k ^ 3), node(i, j, k), 1, 1);
							con(node(i, j, k ^ 2), node(i, j, k), 1, 2);
						}
					}
			}
			else if (tmp == 2) {
				for (int k = 0; k < 4; k++)
					if (mp[i][j] & rd[k]) {
						if (mp[i][j] & rd[k ^ 2]) continue;
						if ((i + j) & 1) con(node(i, j, k), node(i, j, k ^ 2), 1, 1);
						else con(node(i, j, k ^ 2), node(i, j, k), 1, 1);
					}
			}
			else if (tmp == 3) {
				for (int k = 0; k < 4; k++)
					if (!(mp[i][j] & rd[k])) {
						if ((i + j) & 1) {
							con(node(i, j, k ^ 1), node(i, j, k), 1, 1);
							con(node(i, j, k ^ 3), node(i, j, k), 1, 1);
							con(node(i, j, k ^ 2), node(i, j, k), 1, 2);
						}
						else {
							con(node(i, j, k), node(i, j, k ^ 1), 1, 1);
							con(node(i, j, k), node(i, j, k ^ 3), 1, 1);
							con(node(i, j, k), node(i, j, k ^ 2), 1, 2);
						}
					}
			}
		}	
	while (spfa()) update();
	//watch();
	//printf("%lld %lld %lld %lld\n", sum, sum2, ans, ans2);
	if (sum != sum2 || ans < sum) printf("-1\n");
	else printf("%lld\n", ans2);
	return 0;
}
2022/2/13 11:15
加载中...