请求 Hack
查看原帖
请求 Hack
502658
Ray662楼主2024/10/27 21:43

WA 60pts。和前 22 篇题解拍了 n=1000n = 1000 的数据根本拍不出来。

#include <bits/stdc++.h>
#define _for(i, a, b)  for (int i = (a); i <= (b); i ++ )
#define _all(i, a, b)  for (int i = (a); i >= (b); i -- )
#define ull unsigned long long
#define swap(x, y)  if (x ^ y)  x ^= y ^= x ^= y
using namespace std;
const int N = 1005, M = 2005, B1 = 998244353, B2 = 917120411, V = 1e9 + 1;
int n, m, nn, mm, ans, a[N][N], b[M][M]; ull pw1[M], pw2[M], H[4][M][M];
inline ull calc(int o, int x, int y, int X, int Y) { return H[o][X][Y] - H[o][x - 1][Y] * pw1[X - x + 1] - H[o][X][y - 1] * pw2[Y - y + 1] + H[o][x - 1][y - 1] * pw1[X - x + 1] * pw2[Y - y + 1]; }
int main() {
	freopen("data.in", "r", stdin);
	freopen("temp.out", "w", stdout);
	ios :: sync_with_stdio(false), cin.tie(0), cout.tie(0);
	cin >> n >> m, pw1[0] = pw2[0] = 1, nn = n * 2 - 1, mm = m * 2 - 1; int l, r, mid, res, x, y, X, Y;
	_for (i, 1, nn)  pw1[i] = pw1[i - 1] * B1;
	_for (i, 1, mm)  pw2[i] = pw2[i - 1] * B2;
	_for (i, 1, n)  _for (j, 1, m)  cin >> a[i][j];
	_for (i, 1, n)  _for (j, 1, m)  b[(i << 1) - 1][(j << 1) - 1] = a[i][j], b[(i << 1) - 1][j << 1] = V;
	_for (i, 1, n)  _for (j, 1, mm)  b[i << 1][j] = V;
	_for (i, 1, nn)  _for (j, 1, mm)  H[1][i][j] = H[1][i - 1][j] * B1 + H[1][i][j - 1] * B2 - H[1][i - 1][j - 1] * B1 * B2 + b[i][j];
	_for (i, 1, nn >> 1)  _for (j, 1, mm)  swap(b[i][j], b[nn - i + 1][j]);
	_for (i, 1, nn)  _for (j, 1, mm)  H[2][i][j] = H[2][i - 1][j] * B1 + H[2][i][j - 1] * B2 - H[2][i - 1][j - 1] * B1 * B2 + b[i][j];
	_for (i, 1, nn >> 1)  _for (j, 1, mm)  swap(b[i][j], b[nn - i + 1][j]);
	_for (j, 1, mm >> 1)  _for (i, 1, nn)  swap(b[i][j], b[i][mm - j + 1]);
	_for (i, 1, nn)  _for (j, 1, mm)  H[3][i][j] = H[3][i - 1][j] * B1 + H[3][i][j - 1] * B2 - H[3][i - 1][j - 1] * B1 * B2 + b[i][j];
	_for (i, 1, nn)  _for (j, 1, mm)  if (! ((i + j) & 1)) {
		l = 1, r = min(min(i, nn - i + 1), min(j, mm - j + 1)), res = 0;
		while (l <= r) {
			mid = (l + r) >> 1, x = i - mid + 1, y = j - mid + 1, X = i + mid - 1, Y = j + mid - 1;
			if (calc(1, x, y, X, Y) == calc(2, nn - X + 1, y, nn - x + 1, Y) && calc(1, x, y, X, Y) == calc(3, x, mm - Y + 1, X, mm - y + 1))  res = mid, l = mid + 1;
			else  r = mid - 1;
		}
		ans += (res >> 1);
	}
	if (n > 1 && m > 1)  ans += (n << 1) + (m << 1) - 4;
	else  ans += n * m;
	cout << ans << "\n";
	return 0;
}
2024/10/27 21:43
加载中...