WA，60 分。和前 2 篇题解用 n=1000 的数据对拍，完全拍不出差异。
#include <bits/stdc++.h>
// Inclusive ascending loop: i takes the values a, a + 1, ..., b.
#define _for(i, a, b) for (int i = (a); i <= (b); i ++ )
// Inclusive descending loop: i takes the values a, a - 1, ..., b.
#define _all(i, a, b) for (int i = (a); i >= (b); i -- )
// Shorthand for the 64-bit unsigned type; hash arithmetic wraps modulo 2^64.
#define ull unsigned long long
// Exchanges the values of x and y through a temporary.
// The earlier chained-XOR form (x ^= y ^= x ^= y) is undefined behaviour
// before C++17, and its bare leading `if` captured any `else` written after
// the macro call. The do/while(0) wrapper makes the expansion one statement.
// Note: x and y are each evaluated twice, so pass side-effect-free lvalues.
#define swap(x, y) do { auto swap_tmp_ = (x); (x) = (y); (y) = swap_tmp_; } while (0)
using namespace std;
// Capacity of the raw matrix (a) per dimension.
const int N = 1005;
// Capacity of the expanded matrix (b), the power tables and the hash tables.
const int M = 2005;
// Hash base applied along rows (first index); pw1 holds its powers.
const int B1 = 998244353;
// Hash base applied along columns (second index); pw2 holds its powers.
const int B2 = 917120411;
// Filler value stored in the padding cells of b.
const int V = 1000000001;

int n, m;      // dimensions of the input matrix
int nn, mm;    // dimensions of the expanded matrix: 2n - 1 and 2m - 1
int ans;       // running count that is printed at the end
int a[N][N];   // input matrix, 1-based
int b[M][M];   // expanded matrix, 1-based
ull pw1[M];    // pw1[k] = B1^k (mod 2^64)
ull pw2[M];    // pw2[k] = B2^k (mod 2^64)
ull H[4][M][M];  // 2D prefix hashes; only slots 1..3 are used in this file
// Returns the hash of the rectangle with rows x..X and columns y..Y
// (1-based, inclusive) taken from prefix-hash table H[o].
// Standard 2D inclusion-exclusion: the prefixes above and to the left are
// shifted by the matching base powers and removed, and the doubly removed
// corner is added back. All arithmetic wraps modulo 2^64.
inline ull calc(int o, int x, int y, int X, int Y) {
    const ull rowShift = pw1[X - x + 1];
    const ull colShift = pw2[Y - y + 1];
    const ull whole = H[o][X][Y];
    const ull above = H[o][x - 1][Y] * rowShift;
    const ull left = H[o][X][y - 1] * colShift;
    const ull corner = H[o][x - 1][y - 1] * rowShift * colShift;
    return whole - above - left + corner;
}
int main() {
freopen("data.in", "r", stdin);
freopen("temp.out", "w", stdout);
ios :: sync_with_stdio(false), cin.tie(0), cout.tie(0);
cin >> n >> m, pw1[0] = pw2[0] = 1, nn = n * 2 - 1, mm = m * 2 - 1; int l, r, mid, res, x, y, X, Y;
_for (i, 1, nn) pw1[i] = pw1[i - 1] * B1;
_for (i, 1, mm) pw2[i] = pw2[i - 1] * B2;
_for (i, 1, n) _for (j, 1, m) cin >> a[i][j];
_for (i, 1, n) _for (j, 1, m) b[(i << 1) - 1][(j << 1) - 1] = a[i][j], b[(i << 1) - 1][j << 1] = V;
_for (i, 1, n) _for (j, 1, mm) b[i << 1][j] = V;
_for (i, 1, nn) _for (j, 1, mm) H[1][i][j] = H[1][i - 1][j] * B1 + H[1][i][j - 1] * B2 - H[1][i - 1][j - 1] * B1 * B2 + b[i][j];
_for (i, 1, nn >> 1) _for (j, 1, mm) swap(b[i][j], b[nn - i + 1][j]);
_for (i, 1, nn) _for (j, 1, mm) H[2][i][j] = H[2][i - 1][j] * B1 + H[2][i][j - 1] * B2 - H[2][i - 1][j - 1] * B1 * B2 + b[i][j];
_for (i, 1, nn >> 1) _for (j, 1, mm) swap(b[i][j], b[nn - i + 1][j]);
_for (j, 1, mm >> 1) _for (i, 1, nn) swap(b[i][j], b[i][mm - j + 1]);
_for (i, 1, nn) _for (j, 1, mm) H[3][i][j] = H[3][i - 1][j] * B1 + H[3][i][j - 1] * B2 - H[3][i - 1][j - 1] * B1 * B2 + b[i][j];
_for (i, 1, nn) _for (j, 1, mm) if (! ((i + j) & 1)) {
l = 1, r = min(min(i, nn - i + 1), min(j, mm - j + 1)), res = 0;
while (l <= r) {
mid = (l + r) >> 1, x = i - mid + 1, y = j - mid + 1, X = i + mid - 1, Y = j + mid - 1;
if (calc(1, x, y, X, Y) == calc(2, nn - X + 1, y, nn - x + 1, Y) && calc(1, x, y, X, Y) == calc(3, x, mm - Y + 1, X, mm - y + 1)) res = mid, l = mid + 1;
else r = mid - 1;
}
ans += (res >> 1);
}
if (n > 1 && m > 1) ans += (n << 1) + (m << 1) - 4;
else ans += n * m;
cout << ans << "\n";
return 0;
}