mxqz
查看原帖
mxqz
384214
esquigybcu 楼主 2022/2/24 19:11

RT

#include <stdio.h>
#include <string.h>
#include <algorithm>

// 64-bit type used for every modular product (operands are < MOD < 2^30).
using ll = long long;
// Capacity of every coefficient / transform buffer in this file.
const int N = 300005;
// NTT-friendly prime: MOD - 1 = 119 * 2^23.
const int MOD = 998244353;
// g is the root base for forward transforms, g1 its inverse (3 * 332748118 == 1 mod MOD).
const int g = 3;
const int g1 = 332748118;

// Floor of log base 2 of a positive int, i.e. the index of its highest set bit.
inline int log2(int n)
{
    unsigned bits = static_cast<unsigned>(n);
    int highest = -1;
    while (bits != 0u) {
        bits >>= 1;
        ++highest;
    }
    return highest;
}
// Returns x^k mod MOD (k >= 0) by iterative square-and-multiply.
inline ll qpow(ll x, int k)
{
    ll result = 1;
    for (ll base = x; k > 0; k >>= 1) {
        if (k & 1)
            result = result * base % MOD;
        base = base * base % MOD;
    }
    return result;
}

// Bit-reversal permutation table consumed by ntt(); filled by getrev().
int rev[N];
// Fills rev[0..n) so that rev[i] is i with its low log2(n) bits reversed.
// Each entry is derived from rev[i / 2], so indices are processed in increasing order.
inline void getrev(int n)
{
    const int topBit = log2(n) - 1;   // position the lowest bit of i moves to
    for (int idx = 0; idx < n; ++idx) {
        const int shifted = rev[idx >> 1] >> 1;
        const int lowBit = idx & 1;
        rev[idx] = shifted | (lowBit << topBit);
    }
}
// In-place iterative number-theoretic transform of a[0..n) modulo MOD.
// Expects rev[] to have been filled by getrev(n) for the same n (callers do this).
// type == 0 uses powers of g (forward); any other type uses powers of g1, and
// type == 1 additionally multiplies every entry by n^{-1} (inverse transform).
// Butterfly operands are narrowed to int, so entries are expected in [0, MOD).
inline void ntt(ll *a, int n, int type)
{
    // Bit-reversal reordering.
    for (int i = 0; i < n; i++) {
        const int j = rev[i];
        if (i < j)
            std::swap(a[i], a[j]);
    }

    const ll base = (type == 0) ? g : g1;
    for (int half = 1, level = 0; half < n; half <<= 1, ++level) {
        const int len = half << 1;
        // Principal len-th root of unity: base^((MOD - 1) / len).
        const ll rootStep = qpow(base, (MOD - 1) >> (level + 1));
        for (int start = 0; start < n; start += len) {
            ll *lo = a + start;
            ll *hi = a + start + half;
            ll w = 1;
            for (int k = 0; k < half; k++) {
                int u = lo[k];
                int v = w * hi[k] % MOD;
                lo[k] = (u + v) % MOD;
                hi[k] = (u - v + MOD) % MOD;
                w = w * rootStep % MOD;
            }
        }
    }

    if (type == 1) {
        const ll nInv = qpow(n, MOD - 2);
        for (ll *p = a; p != a + n; ++p)
            *p = *p * nInv % MOD;
    }
}

// Multiplicative inverse by Newton iteration: on return b[0..n) holds
// a^{-1} mod x^n and b[n..len) is zero, where len is the smallest power of two >= 2n.
// Requires a[0] to be invertible mod MOD. Overwrites rev[] and a private static buffer.
inline void inv(ll *a, ll *b, int n)
{
    if (n <= 0) return;   // (n + 1) >> 1 would stay 0 and recurse forever
    if (n == 1)
    {
        b[0] = qpow(a[0], MOD - 2);
        return;
    }
    const int half = (n + 1) >> 1;
    inv(a, b, half);
    int rsg23 = 1;
    while (rsg23 < (n << 1)) rsg23 <<= 1;
    // The recursion only zeroes b up to its own (smaller) transform length; clear the
    // whole tail above the half-length inverse so leftover caller data cannot leak
    // into the cyclic convolution below.
    for (int i = half; i < rsg23; i++) b[i] = 0;
    getrev(rsg23);

    static ll c[N];
    for (int i = 0; i < n; i++) c[i] = a[i];
    for (int i = n; i < rsg23; i++) c[i] = 0;
    ntt(c, rsg23, 0); ntt(b, rsg23, 0);
    // Newton step, pointwise in the transform domain: b <- b * (2 - a * b).
    for (int i = 0; i < rsg23; i++)
        b[i] = (2 - c[i] * b[i] % MOD + MOD) % MOD * b[i] % MOD;
    ntt(b, rsg23, 1);
    for (int i = n; i < rsg23; i++) b[i] = 0;   // truncate to mod x^n
}

// Formal derivative truncated to n terms: b[i] = (i + 1) * a[i + 1] for i < n - 1,
// and b[n - 1] = 0. b is cleared first, so an aliased a == b yields all zeros.
inline void diff(ll *a, ll *b, int n)
{
    memset(b, 0, sizeof(ll) * n);
    ll *out = b;
    for (int i = 1; i < n; ++i)
        *out++ = static_cast<ll>(i) * a[i] % MOD;
}
// Formal integral with zero constant term, truncated to n terms:
// b[0] = 0 and b[i + 1] = a[i] / (i + 1) for i + 1 < n.
// Writes exactly b[0..n); nothing at or beyond b[n] is touched.
inline void intg(ll *a, ll *b, int n)
{
    if (n <= 0) return;
    memset(b, 0, n * sizeof(ll));
    // Stop at i + 1 < n: the original loop also wrote b[n], one slot past the result.
    for (int i = 0; i + 1 < n; i++)
        b[i + 1] = a[i] * qpow(i + 1, MOD - 2) % MOD;
}

// Polynomial logarithm: b[0..n) = integral(a' * a^{-1}) mod x^n.
// The constant term is fixed to 0 by intg, so this is ln(a) only when a[0] == 1.
// b needs room for 1 << (log2(n) + 2) entries; on return b[n..) up to that length is 0.
// Overwrites rev[] and a private static scratch buffer.
inline void ln(ll *a, ll *b, int n)
{
    if (n <= 0) return;
    static ll c[N];
    // Strictly greater than 2n, so the product a' * a^{-1} (degree <= 2n - 3) never wraps.
    const int rsg23 = 1 << (log2(n) + 2);

    inv(a, b, n);
    // inv() only clears up to its own transform length, which may be shorter than rsg23.
    for (int i = n; i < rsg23; i++) b[i] = 0;

    diff(a, c, n);
    // c is static and diff() only clears c[0..n): without this, coefficients left by an
    // earlier, larger call wrap around into the low terms of the cyclic product below.
    for (int i = n; i < rsg23; i++) c[i] = 0;

    getrev(rsg23);
    ntt(b, rsg23, 0);
    ntt(c, rsg23, 0);
    for (int i = 0; i < rsg23; i++)
        c[i] = (b[i] * c[i]) % MOD;
    ntt(c, rsg23, 1);
    intg(c, b, n);
    // b[n..) still holds transform residue from above; leave a clean, truncated result.
    for (int i = n; i < rsg23; i++) b[i] = 0;
}

// Polynomial exponential by Newton iteration: b[0..n) = exp(a) mod x^n,
// using b <- b * (1 - ln(b) + a) while doubling the precision. Presumably expects
// a[0] == 0 (the base case is fixed to b[0] = 1). b needs room for
// 1 << (log2(n) + 2) entries; on return b[n..) up to that length is zero.
// Overwrites rev[] and a private static scratch buffer.
inline void exp(ll *a, ll *b, int n)
{
    if (n <= 0) return;   // (n + 1) >> 1 would stay 0 and recurse forever
    if (n == 1)
    {
        b[0] = 1;
        return;
    }
    const int half = (n + 1) >> 1;
    exp(a, b, half);
    const int rsg23 = 1 << (log2(n) + 2);
    // Keep exactly exp(a) mod x^half: otherwise whatever the caller left in b above
    // that point is fed into ln() and the product below.
    for (int i = half; i < rsg23; i++) b[i] = 0;

    static ll lnb[N];
    // ln()/inv() never touch lnb at or beyond rsg23, so only that prefix needs clearing.
    memset(lnb, 0, rsg23 * sizeof(ll));
    ln(b, lnb, n);
    getrev(rsg23);
    // lnb <- a - ln(b) + 1 (mod x^n)
    for (int i = 0; i < n; i++)
    {
        lnb[i] = a[i] - lnb[i];
        if (lnb[i] < 0) lnb[i] += MOD;
    }
    for (int i = n; i < rsg23; i++)
        lnb[i] = 0;
    lnb[0] = (lnb[0] + 1) % MOD;

    ntt(b, rsg23, 0); ntt(lnb, rsg23, 0);
    for (int i = 0; i < rsg23; i++) b[i] = b[i] * lnb[i] % MOD;
    ntt(b, rsg23, 1);
    for (int i = n; i < rsg23; i++) b[i] = 0;   // truncate to mod x^n
}

ll a[N], b[N];

int main()
{
    int n; ll k;
    scanf("%d", &n);
    k = 0; char c = getchar();
    while (c < '0' || c > '9') c = getchar();
    while ('0' <= c && c <= '9')
        k = (10 * k + c - '0') % MOD, c = getchar();
    for (int i = 0; i < n; i++)
        scanf("%lld", &a[i]);
    
    ln(a, b, n);
    for (int i = 0; i < n; i++) b[i] = (b[i] * k) % MOD;
    exp(b, a, n);
    
    for (int i = 0; i < n; i++)
        printf("%lld ", a[i]);

    return 0;
}
2022/2/24 19:11
加载中...