我的 NTT 以 998244353 为模数时是正确的,但以 1e9+7 为模数(原根为 5,原根的逆元为 4e8+3)时是错误的。
例:求 (5−x)⋅(5−x)
998244353 输出:25 998244343 1
1e9+7 输出:800000031 999999997 200000002
神奇码风:
#include <bits/stdc++.h>
// Fixed-width helpers: 32-bit residues and 128-bit intermediates for products.
using u32 = unsigned;
using u128 = unsigned __int128;

// Modulus p, a primitive root g of p, and ig = g^{-1} (mod p).
// NOTE: p - 1 = 1000000006 = 2 * 500000003 and 500000003 is odd, so the largest
// power of two dividing p - 1 is 2. A length-len NTT needs a primitive len-th
// root of unity, which exists only when len divides p - 1. With this modulus
// every transform longer than 2 therefore produces wrong results. Use an
// NTT-friendly prime such as 998244353 = 119 * 2^23 + 1 (line below) instead.
constexpr u32 p = 1000000007, g = 5, ig = 400000003;
//constexpr u32 p(998244353), g(3), ig(332748118);

// e is the value 1 widened to 128 bits (used to promote products);
// L = floor(2^84 / p) is the precomputed Barrett constant used by mod().
constexpr u128 e = 1, L = (e << 84) / p;

// Barrett reduction: returns x mod p as a fully reduced value in [0, p).
// Intended for x < p * p (the product of two reduced residues).
// Because L is rounded down, the estimated quotient can be one too small,
// leaving a remainder in [p, 2p); the final conditional subtraction fixes
// that (e.g. x = (p - 1)^2 would otherwise yield p + 1 instead of 1).
inline u32 mod(const u128& x)
{
    const u128 quotient = (x * L) >> 84;
    const u32 remainder = static_cast<u32>(x - quotient * p);
    return remainder >= p ? remainder - p : remainder;
}
// Residue modulo p. The constructor is intentionally implicit so that plain
// unsigned values convert to mint; it stores the value as given, without
// reducing it.
struct mint
{
    u32 x;

    inline mint(const u32& value = 0) : x(value) { }
};
// Modular addition. Only one conditional subtraction of p is performed, so
// both operands are assumed to already be below p.
inline mint operator + (const mint& a, const mint& b)
{
    const auto sum = a.x + b.x;
    if (sum < p) return sum;
    return sum - p;
}
// Modular subtraction. Adds p before subtracting when a.x is smaller than
// b.x, so the unsigned difference never wraps for operands below p.
inline mint operator - (const mint& a, const mint& b)
{
    if (a.x >= b.x) return a.x - b.x;
    return a.x + p - b.x;
}
// Modular multiplication: forms the product in 128-bit arithmetic (by
// multiplying through the 128-bit constant e) and reduces it with mod().
inline mint operator * (const mint& a, const mint& b)
{
    const auto widened = e * a.x;
    return mod(widened * b.x);
}
// Binary exponentiation: returns x raised to the power y, using the mint
// multiplication operator. fpow(x, 0) yields 1.
inline mint fpow(mint x, u32 y)
{
    mint acc(1);
    for (; y != 0; y >>= 1)
    {
        // Fold the current square into the result when this bit of y is set.
        if ((y & 1) != 0) acc = acc * x;
        x = x * x;
    }
    return acc;
}
// A polynomial, stored as a vector of mint coefficients.
using poly = std::vector<mint>;
// Bit-reversal permutation table: filled by rev_init() and read by NTT().
std::vector<u32> rev;
// Picks the transform length for a result of n coefficients and rebuilds the
// global bit-reversal table `rev` for that length.
//   n       - number of coefficients the caller needs.
//   returns - len, the smallest power of two strictly greater than n.
// The table must hold len entries (not n): the loop below writes rev[0..len-1]
// and NTT() reads the same range, so sizing it to n would overrun the vector.
inline u32 rev_init(const u32& n)
{
    u32 len(1), lg(0);
    while (len <= n) { len <<= 1; lg++; }

    rev.resize(len);
    rev[0] = 0;
    // rev[i] is rev[i / 2] shifted down one bit, with the lowest bit of i
    // moved to the top bit position (lg - 1).
    for (u32 i(1); i < len; i++)
    {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
    }
    return len;
}
// In-place number-theoretic transform of A modulo p.
//   A   - coefficient vector; resized to len and overwritten with the result.
//   len - transform length. It must be the power of two for which rev_init()
//         last built the global `rev` table, since rev[0..len-1] is read here.
//   op  - nonzero: forward transform, using powers of g;
//         zero: inverse transform, using powers of ig and then multiplying
//         every entry by len^{-1} (computed as len^(p-2)).
// A primitive len-th root of unity exists only if len divides p - 1. When it
// does not (e.g. p = 1e9+7 and len >= 4), (p - 1) / len is truncated and the
// butterflies silently compute garbage, so the precondition is asserted.
inline void NTT (poly& A, const u32& len, const u32& op)
{
    assert(len != 0 && (p - 1) % len == 0 &&
           "NTT length must divide p - 1; the modulus is not NTT-friendly for this size");

    A.resize(len);
    // Reorder into bit-reversed order; the i < rev[i] test swaps each pair once.
    for (u32 i(0); i < len; i++)
    {
        if (i < rev[i]) std::swap(A[i], A[rev[i]]);
    }

    // Butterfly passes: `half` is the half-width of the blocks being merged.
    for (u32 half(1); half < len; half <<= 1)
    {
        const u32 span = half << 1;
        // Primitive span-th root of unity (or its inverse for op == 0).
        const mint step = fpow(op ? g : ig, (p - 1) / span);
        for (u32 base(0); base < len; base += span)
        {
            mint w(1);
            for (u32 k(0); k < half; k++, w = w * step)
            {
                const mint lo = A[base + k];
                const mint hi = w * A[base + half + k];
                A[base + k] = lo + hi;
                A[base + half + k] = lo - hi;
            }
        }
    }

    if (!op)
    {
        const mint inv(fpow(len, p - 2));
        for (mint& a : A) a = a * inv;
    }
}
// Polynomial product via NTT: transforms both operands, multiplies them
// pointwise, and transforms back.
//   A, B    - operands, taken by value because NTT() resizes and overwrites them.
//   returns - a poly with A.size() + B.size() - 1 coefficients.
// When both operands are empty the size expression would wrap around in
// unsigned arithmetic and rev_init() would never terminate, so that case
// returns an empty poly directly.
inline poly operator * (poly A, poly B)
{
    if (A.empty() && B.empty()) return poly();

    const u32 n(A.size() + B.size() - 1);
    const u32 len(rev_init(n));

    NTT(A, len, 1);
    NTT(B, len, 1);

    poly res(len);
    for (u32 i(0); i < len; i++) res[i] = A[i] * B[i];

    NTT(res, len, 0);
    res.resize(n);   // drop the padding above the true result size
    return res;
}
// f: array of N polynomials (N is not defined in the visible part of this file);
// ans: receives the product computed in main(); tot: not used in the visible code.
poly f[N], ans; u32 tot;
// Demo driver: multiplies the polynomial with coefficients {5, p - 1} by itself
// and passes each coefficient of the product to write() (defined outside the
// visible source; presumably prints the value).
signed main(void)
{
    f[2] = {5, p - 1};   // p - 1 is congruent to -1 modulo p, i.e. 5 - x
    f[1] = f[2];

    ans = f[1] * f[2];
    for (auto& coeff : ans)
    {
        write(coeff.x);
    }
    return 0;
}