样例过不去求调
查看原帖
样例过不去求调
762646
Piggy343288楼主2025/6/17 11:49

如图，此代码能通过 FFT 模板题，但过不了这道模板题的样例。

#include <bits/stdc++.h>
using namespace std;

/// Length of the fixed-size arrays declared throughout this file.
constexpr int maxN = 3000010;
/// Prime modulus for all arithmetic (998244353 = 119 * 2^23 + 1).
constexpr int mod = 998244353;

/// Computes a^b modulo `mod` by binary exponentiation.
///
/// @param a Base; any int. It is first reduced into [0, mod), so negative
///          bases give a canonical non-negative residue.
/// @param b Exponent. A negative exponent is mapped to the equivalent
///          positive one modulo (mod - 1) (Fermat's little theorem), which
///          is the modular inverse power when `a` is not a multiple of `mod`.
/// @return  The residue in [0, mod). qpow(x, 0) is 1 for every x.
inline int qpow(int a, int b) {
    a %= mod;
    if (a < 0)
        a += mod;
    // Without this, a negative b would never reach zero under `b >>= 1`.
    if (b < 0)
        b = b % (mod - 1) + (mod - 1);
    int ans = 1;
    while (b) {
        if (b & 1)
            ans = 1ll * ans * a % mod;
        a = 1ll * a * a % mod;
        b >>= 1;
    }
    return ans;
}

// g is the generator from which the transform's roots of unity are derived;
// gi is its modular inverse (3 * 332748118 == 998244353 + 1) and is used in
// its place for the inverse transform.
constexpr int g = 3, gi = 332748118;
// Global work arrays of maxN ints each. Nothing in the code visible here
// reads or writes them.
int a[maxN], b[maxN], c[maxN];

struct _NTT {
    int limit, len, rev[maxN];
    void NTT(int *A, int type) {
        for (int i = 0; i < limit; i++)
            if (i < rev[i])
                swap(A[i], A[rev[i]]);
        for (int i = 1; i < limit; i <<= 1) {
            int gn = qpow((type == 1 ? g : gi), (mod - 1) / (i << 1));
            for (int j = 0; j < limit; j += (i << 1)) {
                int g = 1;
                for (int k = 0; k < i; k++, g = 1ll * g * gn % mod) {
                    int x = A[j + k], y = 1ll * g * A[j + k + i] % mod;
                    A[j + k] = (x + y) % mod,
                          A[j + k + i] = (x - y + mod) % mod;
                }
            }
        }
        if (type == -1) {
            int inv = qpow(limit, mod - 2);
            for (int i = 0; i <= limit; i++)
                A[i] = 1ll * A[i] * inv % mod;
        }
    }
    void mul(int *a, int *b, int *c, int n, int m) {
        limit = 1, len = 0;
        while (limit <= (n + m))
            limit <<= 1, len++;
        for (int i = 0; i < limit; i++) {
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1)) << (len - 1);
        }
        NTT(a, 1);
        NTT(b, 1);
        for (int i = 0; i < limit; i++)
            c[i] = 1ll * a[i] * b[i] % mod;
        NTT(c, -1);
    }
    void inv(int *a, int *b, int *c, int n, int m) {
        limit = 1, len = 0;
        while (limit <= (n + m))
            limit <<= 1, len++;
        for (int i = 0; i < limit; i++) {
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1)) << (len - 1);
        }
        NTT(a, 1);
        NTT(b, 1);
        for (int i = 0; i < limit; i++)
            c[i] = 1ll * b[i] * (2 - 1ll * a[i] * b[i] % mod + mod) % mod;
        NTT(c, -1);
    }
} ntt;

struct Poly {
    int n, a[maxN];
    Poly() {}
    Poly(int deg) { n = deg; }
    Poly trim(int deg) { n = deg; return *this; }
    void input(int deg) {
        n = deg;
        for (int i = 0; i <= deg; i++)
            cin >> a[i];
    }
    void output() {
        for (int i = 0; i <= n; i++)
            cout << a[i] << " ";
    }
    Poly operator+(Poly &rhs) {
        int len = rhs.n > n ? rhs.n : n;
        Poly ans(len);
        for (int i = 0; i <= len; i++) {
            ans.a[i] = a[i] + rhs.a[i];
            if (ans.a[i] >= mod)
                ans.a[i] -= mod;
        }
        return ans;
    }
    Poly operator+(int x) {
        Poly ans = *this;
        ans.a[0] += x;
        if (ans.a[0] >= mod)
            ans.a[0] -= mod;
        return ans;
    }
    Poly operator-(Poly &rhs) {
        int len = rhs.n > n ? rhs.n : n;
        Poly ans(len);
        for (int i = 0; i <= len; i++) {
            ans.a[i] = a[i] - rhs.a[i];
            if (ans.a[i] < 0)
                ans.a[i] += mod;
        }
        return ans;
    }
    Poly operator-(int x) {
        Poly ans = *this;
        ans.a[0] -= x;
        if (ans.a[0] < 0)
            ans.a[0] += mod;
        return ans;
    }
    Poly operator*(Poly &rhs) {
        Poly ans;
        ntt.mul(a, rhs.a, ans.a, n, rhs.n);
        ans.n = n + rhs.n;
        return ans;
    }
    Poly operator*(int x) {
        int flag = 1;
        if (x < 0) {
            flag = -1, x = -x;
        }
        Poly ans = *this;
        if (x != 1) {
            for (int i = 0; i <= n; i++) {
                ans.a[i] = 1ll * a[i] * x % mod;
            }
        }
        if (flag == -1) {
            for (int i = 0; i <= n; i++) {
                ans.a[i] = mod - ans.a[i];
            }
        }
        return ans;
    }
    Poly operator*=(Poly &rhs) { return *this = *this * rhs; }
    Poly inv(int len) {
        if (len == 1) {
            Poly ans(0);
            ans.a[0] = qpow(a[0], mod - 2);
            return ans;
        }
        Poly f0 = inv((len + 1) >> 1);
        Poly ans(len - 1); ntt.inv(a, f0.a, ans.a, n, f0.n);
        return ans;
    }
} F, G;

signed main() {
    int n;
    cin >> n;
    n--;
    F.input(n);
    (F.inv(n + 1)).output();
}
2025/6/17 11:49
加载中...