求调qwq
查看原帖
求调qwq
1050375
czhusi（楼主） 2025/7/24 20:39

不知道为什么第 3 和第 8 个测试点 WA 了。

#include <bits/stdc++.h>

using namespace std;
using ll = long long;

// Array bound with slack. Presumably n, m <= 1000 and book ids <= n*m —
// TODO confirm against the problem statement.
constexpr int N = 1010;
constexpr int INF = 1000000010;  // same value as 1e9 + 10; not referenced in the visible code

int n;                 // number of shelves (read in main)
int m;                 // slots per shelf (read in main)
int a[N][N];           // initial layout, 1-indexed; 0 appears to mean an empty slot
int b[N][N];           // target layout, same indexing as a
int pl[N * N];         // pl[book] = 0-based rank of book among the non-empty slots of a[id] (set by P)
int pos[N * N];        // pos[book] = shelf whose initial layout contains book (non-vad shelves only)
int vad[N];            // vad[i] = 1 when shelf i holds the same set of books in a and b
int f;                 // 1 if at least one slot of the initial layout a is 0
int st[N * N];         // per-book scratch marks / counters (returned to 0 after each use)
int ans;               // accumulated answer
int c[N];              // Fenwick tree storing prefix maxima (used by UP / Query)

std::vector<int> com;    // shelves collected by the current dfs
std::vector<int> gh[N];  // undirected graph between shelves that exchange books

bool vis[N];           // dfs visited marks

// Returns the lowest set bit of x (0 for x == 0): the Fenwick-tree step size.
int lowbit(int x) {
    const int twos_complement_negation = ~x + 1;
    return x & twos_complement_negation;
}

// Fenwick prefix-max update: raises c[x] and every node above it (up to m)
// to at least p. Positions are 1-based; x must be >= 1.
void UP(int x, int p) {
    while (x <= m) {
        if (c[x] < p) {
            c[x] = p;
        }
        x += x & -x;  // jump to the next node whose range covers x
    }
}

// Fenwick prefix-max query: maximum value stored at positions 1..x
// (0 when x == 0 or nothing has been stored yet).
int Query(int x) {
    int best = 0;
    while (x != 0) {
        best = max(best, c[x]);
        x &= x - 1;  // clear the lowest set bit, i.e. x -= lowbit(x)
    }
    return best;
}

// Depth-first search over the shelf graph gh starting at shelf x.
// Marks each reached shelf in vis and appends it to com in visit order.
void dfs(int x) {
    vis[x] = true;
    com.emplace_back(x);
    const vector<int>& neighbours = gh[x];
    for (size_t k = 0; k < neighbours.size(); ++k) {
        const int next = neighbours[k];
        if (vis[next]) {
            continue;
        }
        dfs(next);
    }
}

// Longest increasing subsequence for shelf `id`.
//
// Each non-empty slot of a[id] gets its 0-based left-to-right rank in pl[].
// The books of b[id] are then scanned left to right, and the function
// returns the length of the longest run of them whose a-ranks are strictly
// increasing. This is computed with the prefix-max Fenwick tree c.
// Assumes every positive entry of b[id] also appears in a[id], so its pl[]
// rank is current; callers mask a[id]/b[id] to guarantee this.
int P(int id) {
    const int* initial = a[id];
    const int* target = b[id];

    int cnt = 0;
    for (int i = 1; i <= m; i++) {
        if (initial[i] > 0) {
            pl[initial[i]] = cnt++;
        }
    }

    fill(c, c + m + 1, 0);
    for (int i = 1; i <= m; i++) {
        const int book = target[i];
        if (book <= 0) {
            continue;
        }
        const int rank = pl[book];
        // Best chain ending with a smaller rank, extended by this book.
        UP(rank + 1, Query(rank) + 1);
    }

    // Query is a prefix maximum, so the best over positions 1..cnt is Query(cnt).
    return Query(cnt);
}

void Solve(vector<int>& v) {
    bool f2 = 0;
    for (int id : v) {
        for (int j = 1; j <= m; j++) {
            if (a[id][j] == 0) {
                f2 = 1;
            }
        }
    }

    int cnt = 0;
    for (int id : v) {
        for (int j = 1; j <= m; j++) {
            if (a[id][j] > 0) {
                st[a[id][j]]++;
                cnt++;
            }
            if (b[id][j] > 0) {
                st[b[id][j]]++;
            }
        }
        int tmp1[N], tmp2[N];
        for (int j = 1; j <= m; j++) {
            tmp1[j] = a[id][j];
            tmp2[j] = b[id][j];
        }

        for (int j = 1; j <= m; j++) {
            if (a[id][j] > 0 && st[a[id][j]] < 2) {
                a[id][j] = 0;
            }
            if (b[id][j] > 0 && st[b[id][j]] < 2) {
                b[id][j] = 0;
            }
        }
        ans -= P(id);
        for (int j = 1; j <= m; j++) {
            a[id][j] = tmp1[j];
            b[id][j] = tmp2[j];
        }
        for (int j = 1; j <= m; j++) {
            if (a[id][j] > 0) {
                st[a[id][j]]--;
            }
            if (b[id][j] > 0) {
                st[b[id][j]]--;
            }
        }
    }
    ans += cnt;
    if (!f2) {
        if (!f) {
            cout << "-1";
            exit(0);
        } else {
            ans++;
        }
    }
}

/*
 * Reads n, m, the initial layout a and the target layout b (n rows of m
 * values each), then prints the answer (or "-1").
 *
 * Pass 1: a shelf whose set of books is identical in a and b (vad[i] = 1)
 *         costs (its book count) - P(i). If it has no 0 slot (cnt == m) and
 *         needs any move, one extra move is charged, or "-1" is printed when
 *         the global flag f says no slot of a is 0 at all.
 * Pass 2: the remaining shelves are linked whenever a book starts on one and
 *         ends on another, and each connected group is handled by Solve().
 */
int main() {
    ios::sync_with_stdio(0), cin.tie(0);
    cin >> n >> m;
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            cin >> a[i][j];
            // Global flag: some slot of the initial layout is 0.
            if (a[i][j] == 0) {
                f = 1;
            }
        }
    }
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            cin >> b[i][j];
        }
    }

    for (int i = 1; i <= n; i++) {
        // BUG FIX: this flag used to be declared as `int f`, which shadowed the
        // global `f` above, so the `!f` check below always saw 1 and the "-1"
        // branch was unreachable. It now has its own name.
        bool same = true;
        int cnt = 0;
        for (int j = 1; j <= m; j++) {
            if (a[i][j] > 0) {
                st[a[i][j]] = 1;
                cnt++;
            }
        }
        for (int j = 1; j <= m; j++) {
            if (b[i][j] > 0 && !st[b[i][j]]) {
                same = false;  // b[i] has a book that a[i] lacks
            } else if (b[i][j] > 0) {
                st[b[i][j]] = 0;
            }
        }
        // Any mark still set is a book of a[i] missing from b[i]; clear it.
        for (int j = 1; j <= m; j++) {
            if (a[i][j] > 0 && st[a[i][j]]) {
                same = false;
                st[a[i][j]] = 0;
            }
        }
        if (same) {
            vad[i] = 1;
            int tmp = cnt - P(i);
            ans += tmp;
            if (cnt == m && tmp != 0) {
                if (!f) {  // global flag: no 0 slot exists anywhere
                    cout << "-1";
                    return 0;
                }
                ans++;
            }
        }
    }

    // Remember which (non-vad) shelf each book starts on.
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            if (a[i][j] > 0 && !vad[i]) {
                pos[a[i][j]] = i;
            }
        }
    }

    // Link shelf i with the shelf its target books start on.
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= m; j++) {
            if (b[i][j] > 0 && !vad[i] && pos[b[i][j]] != i) {
                gh[pos[b[i][j]]].push_back(i);
                gh[i].push_back(pos[b[i][j]]);
            }
        }
    }
    for (int i = 1; i <= n; i++) {
        if (!vad[i] && !vis[i]) {
            com.clear();
            dfs(i);
            Solve(com);
        }
    }
    cout << ans;
    return 0;
}
2025/7/24 20:39
加载中...