蒟蒻在写多项式乘法时遇到一个奇怪的问题:本地和学校OJ都可编译,但在洛谷上一直提示“编译失败”,代码如下:
#include <bits/stdc++.h>
#define clr(a, n) memset(a, 0, sizeof(int) * n)
#define cpy(a, b, n) memcpy(a, b, sizeof(int) * n)
typedef long long ll;
typedef unsigned long long ull;
typedef double f64;
typedef long double f128;
const int MAXN = 2000005, MOD = 998244353, _G = 3;
int read()
{
int s = 0, f = 1;
char ch = getchar();
while (!(ch >= '0' && ch <= '9'))
{
if (ch == '-')
f = -1;
ch = getchar();
}
while (ch >= '0' && ch <= '9')
s = (s << 3) + (s << 1) + (ch ^ 48), ch = getchar();
return s * f;
}
void pttimes(int *f, int *g, int len)
{
for (int i = 0; i < len; i++)
f[i] = 1ll * f[i] * g[i] % MOD;
}
int tr[MAXN << 1], curn;
void tsetup(int n)
{
if (curn == n)
return ;
curn = n;
for (int i = 0; i < n; i++)
tr[i] = tr[i >> 1] >> 1 | ((i & 1) ? (n >> 1) : 0);
}
ll fpow(ll x, ll y)
{
ll res = 1;
for (; y; y >>= 1, x = x * x % MOD)
if (y & 1)
res = res * x % MOD;
return res;
}
inline ll inv(int x)
{
return fpow(x, MOD - 2);
}
const int invG = inv(_G);
void NTT(int *g, bool fl, int n)
{
tsetup(n);
static ull f[MAXN << 1], w[MAXN << 1] = {1};
for (int i = 0; i < n; i++)
f[i] = (((ll)MOD << 5) + g[tr[i]]) % MOD;
for (int len = 1; len < n; len <<= 1)
{
ull rt = fpow(fl ? _G : invG, (MOD - 1) / (len << 1));
for (int i = 1; i < len; i++)
w[i] = w[i - 1] * rt % MOD;
for (int p = 0; p < n; p += (len << 1))
for (int cur = 0; cur < len; cur++)
{
int t = w[cur] * f[p | len | cur] % MOD;
f[p | len | cur] = f[p | cur] + MOD - t;
f[p | cur] = f[p | cur] + t;
}
if (len == (1 << 10))
for (int i = 0; i < n; i++)
f[i] %= MOD;
}
if (fl)
for (int i = 0; i < n; i++)
g[i] = f[i] % MOD;
else
{
ull invn = inv(n);
for (int i =0; i < n; i++)
g[i] = f[i] % MOD * invn % MOD;
}
}
void polytimes(int *f, int *g, int len, int lim)
{
static int tmp[MAXN << 1];
int n = 1;
for (; n < (len << 1); n <<= 1) ;
clr(tmp, n), cpy(tmp, g, n);
NTT(f, 1, n), NTT(tmp, 1, n);
pttimes(f, tmp, n), NTT(f, 0, n);
clr(f + lim, n - lim), clr(tmp, n);
}
int a[MAXN << 2] = {0}, b[MAXN << 2] = {0}, n = 0, m = 0;
void read(int a[], int &n)
{
char ch = getchar();
while (!(ch >= '0' && ch <= '9'))
ch = getchar();
while (ch >= '0' && ch <= '9')
a[n++] = ch -'0', ch = getchar();
std::reverse(a, a + n);
}
int main()
{
read(a, n), read(b, m);
polytimes(a, b, std::max(n, m), n + m);
n += m;
for (; a[n] == 0; n--) ;
for (int i = 0; i <= n; i++)
{
if (a[i] / 10 && n == i)
n++;
a[i + 1] += a[i] / 10, a[i] %= 10;
}
for (; ~n; n--)
printf("%d", a[n]);
#ifndef ONLINE_JUDGE
system("pause");
#endif
return 0;
}