萌新刚学OI,插头DP求调
查看原帖
萌新刚学OI,插头DP求调
262147
Reobrok_Kk楼主2024/10/3 20:53

RT

思路是:对当前状态中正在规划的格子,按它的左边及上边的格子(即状态中第 j、j+1 位)是否选取进行分类讨论,更新状态。把状态丢进 hash 表时,若状态中的连通块数为 1,就统计答案。

不知道为何写挂了。。。

求各路神仙们调试!

如果能给出 hack 也行!十分感谢!!!


#include <bits/stdc++.h>
// Competitive-programming shorthand: every `int` below (including the
// constants, arrays and function signatures) is really `long long`.
#define int long long
// Debug helpers: print a value followed by a space / end the line.
#define debug(x) cout << (x) << " ";
#define nxt cout << endl;
using namespace std;
// Bound for grid indices and for state positions. The DP uses positions
// 1..n+1, so an 11-wide bound presumably assumes n <= 9 — confirm against
// the problem's limits.
const int N = 11;
// Capacity of each hash table (slot count and maximum number of stored states).
const int Mod = 114514;
// mp: 1-based grid weights. ans: best connected-set weight found so far;
// main() raises it to the largest single cell before the DP runs.
int mp[N][N], ans = -INT_MAX;
typedef long long ll;
// Read the 3-bit label stored at 1-based position k of the packed state s.
int get(int s, int k) {
    const int shift = 3 * (k - 1);
    return (s >> shift) & 7;
}
// Overwrite the 3-bit label at 1-based position k of state s with val.
// XOR-ing with (old ^ val) clears the old label and writes the new one in one step.
void update(int &s, int k, int val) {
    const int shift = (k - 1) * 3;
    const int old = (s >> shift) & 7;
    s ^= (old ^ val) << shift;
}
// Open-addressing hash table (linear probing) that maps a normalized
// connectivity state to the best total weight reached with that state.
//   hsh[slot] : index (1..siz) of the state stored in that slot, -1 = empty
//   tru[idx]  : packed state number idx
//   maxx[idx] : best weight recorded for tru[idx]
//   siz       : number of distinct states stored
// NOTE(review): there is no overflow guard. If siz ever reached Mod, tru[siz]
// would go out of bounds and probing would never terminate. Confirm Mod is
// large enough for the maximum number of states.
struct Hash {
    int hsh[Mod], siz, tru[Mod]; int maxx[Mod];

    // Empty the table. maxx is filled with 0xC1 bytes, which is a large
    // negative sentinel.
    void clear() {
        memset(hsh, -1, sizeof hsh);
        memset(tru, -1, sizeof tru);
        memset(maxx, -0x3f, sizeof maxx);
        siz = 0;
    }

    // Insert state zt with weight value, keeping the maximum weight per state.
    // Labels are first renumbered 1..cnt in order of first appearance, so that
    // equivalent states map to the same key. If the state holds exactly one
    // component, value is also a candidate for the global answer.
    void add(int zt, int value) {
        // A row of width n uses state positions 1..n+1, which is up to N-1 = 10.
        // Scanning only 1..9 dropped position 10 when n == 9. That lost the
        // cell above column 9 and could count a disconnected selection as a
        // single component.
        const int kPositions = N - 1;
        int z[N] = {0}, p[N] = {0}, cnt = 0;
        for (int i = 1; i <= kPositions; ++ i) z[i] = get(zt, i);
        for (int i = 1; i <= kPositions; ++ i)
            if (z[i] && !p[i]) {
                ++ cnt;
                for (int j = i; j <= kPositions; ++ j)
                    if (z[j] == z[i]) p[j] = cnt;
            }
        if (cnt == 1) ans = max(ans, value);
        zt = 0;
        for (int i = 1; i <= kPositions; ++ i) update(zt, i, p[i]);
        int s = zt % Mod;
        while (hsh[s] != -1 && tru[hsh[s]] != zt) s = (s + 1) % Mod;
        if (hsh[s] == -1) { hsh[s] = ++ siz; tru[siz] = zt; maxx[siz] = value; }
        else maxx[hsh[s]] = max(maxx[hsh[s]], value);
    }
}f[2];
// Connectivity DP over an n x n grid, processed cell by cell in row-major
// order. A packed state stores a 3-bit label per position 1..n+1.
// While cell (i, j) is being processed:
//   - positions 1..j-1 hold row i, cells 1..j-1;
//   - position j holds the left cell (i, j-1);
//   - positions j+1..n+1 hold row i-1, cells j..n (so j+1 is the upper cell).
// A nonzero label is the component id of a selected cell; 0 = not selected.
// ans is raised inside Hash::add whenever a state has exactly one component.
// NOTE(review): this relies on Hash::add examining all n+1 positions.
signed main() {
    int n; cin >> n;
    for (int i = 1; i <= n; ++ i)
        for (int j = 1; j <= n; ++ j)
            { cin >> mp[i][j]; ans = max(ans, mp[i][j]);}  // any single cell is a valid connected set
    // f[curr ^ 1] holds the states before the current cell; f[curr] receives the new ones.
    int curr = 1;
    f[0].clear();
    f[1].clear();
    f[0].add(0, 0);  // start: nothing selected, weight 0
    for (int i = 1; i <= n; ++ i) {
        for (int j = 1; j <= n; ++ j) {
            for (int k = 1; k <= f[curr ^ 1].siz; ++ k) {
                int s = f[curr ^ 1].tru[k], t = s;
                // plug1 = left cell (i, j-1), plug2 = upper cell (i-1, j)
                int plug1 = get(s, j), plug2 = get(s, j + 1);
                // debug(i); debug(j); debug(s); debug(plug1); debug(plug2); debug(f[curr ^ 1].maxx[k]); nxt;
                // for (int l = 1; l <= 9; ++ l) debug(get(s, l));
                // nxt;
                if (!plug1 && !plug2) {
                    // Neither neighbour selected: skip the cell (state unchanged), or
                    // start a new component. 7 is a temporary label; add() renumbers
                    // labels to 1..cnt.
                    f[curr].add(s, f[curr ^ 1].maxx[k]);
                    update(t, j, 7); if (j != n) update(t, j + 1, 7);  // last column has no right neighbour
                    f[curr].add(t, f[curr ^ 1].maxx[k] + mp[i][j]);
                } else if (plug1 && !plug2) {
                    // Only the left cell is selected. Skipping is always allowed
                    // because its label is still stored at position j-1.
                    update(t, j, 0);
                    f[curr].add(t, f[curr ^ 1].maxx[k]);
                    // Otherwise select the cell and extend the left cell's component.
                    update(t, j, plug1); if (j != n) update(t, j + 1, plug1);
                    f[curr].add(t, f[curr ^ 1].maxx[k] + mp[i][j]);
                } else if (!plug1 && plug2) {
                    // Only the upper cell is selected: selecting joins its component.
                    update(t, j, plug2); if (j == n) update(t, j + 1, 0);
                    f[curr].add(t, f[curr ^ 1].maxx[k] + mp[i][j]);
                    // Skipping removes the upper cell from the frontier. Allow it only if
                    // its component still appears elsewhere; otherwise that component
                    // would be closed off and could never reconnect.
                    int cnt = 0;
                    for (int l = 1; l <= n + 1; ++ l) cnt += (l != j && l != j + 1 && get(s, l) == plug2);
                    if (cnt) {
                        update(t, j, 0); update(t, j + 1, 0);
                        f[curr].add(t, f[curr ^ 1].maxx[k]);
                    }
                } else {
                    // Both neighbours are selected. Skipping is allowed only if the upper
                    // cell's component survives elsewhere (the left cell survives at j-1).
                    int cnt = 0;
                    for (int l = 1; l <= n + 1; ++ l) cnt += (l != j && l != j + 1 && get(s, l) == plug2);
                    if (cnt) {
                        update(t, j, 0); update(t, j + 1, 0);
                        f[curr].add(t, f[curr ^ 1].maxx[k]);
                    }
                    // Selecting merges the two components: relabel plug2 as plug1 everywhere.
                    for (int l = 1; l <= n + 1; ++ l)  if (get(s, l) == plug2) update(t, l, plug1);
                    update(t, j, plug1); if (j == n) update(t, j + 1, 0); else update(t, j + 1, plug1);
                    f[curr].add(t, f[curr ^ 1].maxx[k] + mp[i][j]);
                }
            }
            // Swap roles: the freshly filled table becomes the source for the next cell.
            curr ^= 1;
            f[curr].clear();
        }
        // End of row: move every position k to k+1 (3 bits each). Position 1 becomes
        // the empty left neighbour of the next row's first cell. Only tru[] is
        // rewritten; the table is read, not probed, before being cleared.
        if (i != n) for (int k = 1; k <= f[curr ^ 1].siz; ++ k) f[curr ^ 1].tru[k] <<= 3;
    }
    cout << ans << "\n";
}
2024/10/3 20:53
加载中...