以下是某题的代码,答案正确但是会超时。
#include <bits/stdc++.h>
using namespace std;
// N bounds the second dimension of C and f; M bounds the first dimension of
// f and sq (and, together with N, the first dimension of C).
// MOD is the modulus used by madd/mmul and when filling sq.
const int N = 505, M = 1e5 + 5, MOD = 998244353;
// n and m are read from stdin in main; ans accumulates the value main prints.
int n, m, ans;
// C[i][j]: filled in main with C[i][0] = 1 and
//          C[i][j] = C[i-1][j] + C[i-1][j-1] (mod MOD), i.e. Pascal's rule.
// sq[i]:   filled in main with 4 * i * i mod MOD.
// f:       not referenced anywhere in the visible code.
// NOTE(review): main indexes C up to row m + n and column n, so it assumes
// n < N and m < M -- confirm against the problem's input limits.
int C[M + N][N], f[M][N], sq[M];
// Modular addition: returns (x + y) reduced into [0, MOD) for operands that
// are each at most MOD (so the int sum cannot overflow).
// The comparison must be >=, not >: with > a sum of exactly MOD was returned
// unreduced, so a result that should be 0 could come out as MOD.
inline int madd(int x, int y)
{
    const int sum = x + y;
    return sum >= MOD ? sum - MOD : sum;
}
// Modular multiplication: the product is formed in 64-bit arithmetic so it
// cannot overflow int, then reduced modulo MOD and narrowed back to int.
inline int mmul(int x, int y)
{
    const long long product = static_cast<long long>(x) * y;
    return static_cast<int>(product % MOD);
}
// Reads n and m from stdin, fills the binomial table C (rows 0..m+n,
// columns 0..n, values mod MOD) and the table sq[j] = 4*j*j mod MOD, then
// prints, modulo MOD, the sum over 1 <= i < n and i <= j <= m of
//   C[m-j+n-1][n-1] * C[j-1][i-1] * C[j+n-i-1][n-i-1] * C[n][n-i] * sq[j].
// Assumes n < N and m < M so every index stays inside the global arrays.
int main()
{
    scanf("%d%d", &n, &m);

    // Pascal's triangle modulo MOD; column 0 is the base case.
    for (int i = 0; i <= m + n; i++)
        C[i][0] = 1;
    for (int i = 1; i <= m + n; i++)
        for (int j = 1; j <= n; j++)
            C[i][j] = madd(C[i - 1][j], C[i - 1][j - 1]);

    for (int i = 0; i <= m; i++)
        sq[i] = 4ll * i * i % MOD;

    // i stops at n - 1: for i == n the factor C[j + n - i - 1][n - i - 1]
    // has column index -1, which is an out-of-bounds read (before the start
    // of C when j == 1). A binomial with lower index -1 is taken as 0, so
    // that term contributes nothing and is skipped instead of being read.
    //
    // NOTE: with j as the inner index, C[m - j + n - 1][..], C[j - 1][..]
    // and C[j + n - i - 1][..] each move to a different row of C on every
    // inner step (rows are N ints apart), which is the access pattern the
    // surrounding question is about.
    for (int i = 1; i < n; i++)
    {
        // Does not depend on j, so it is looked up once per outer iteration.
        const int rowFactor = C[n][n - i];
        for (int j = i; j <= m; j++)
        {
            const int head = mmul(mmul(C[m - j + n - 1][n - 1], C[j - 1][i - 1]),
                                  C[j + n - i - 1][n - i - 1]);
            const int tail = mmul(rowFactor, sq[j]);
            ans = madd(ans, mmul(head, tail));
        }
    }

    printf("%d", ans);
    return 0;
}
但是,如果把这段
// Original order: i is the outer index, j the inner one. On every inner step
// the FIRST index of C[m - j + n - 1][..], C[j - 1][..] and
// C[j + n - i - 1][..] changes, so each of these three reads moves to a
// different row of C (rows are N ints long); only C[n][n - i] stays put.
for (int i = 1; i <= n; i++)
for (int j = i; j <= m; j++)
ans = madd(ans, mmul(mmul(mmul(mmul(C[m - j + n - 1][n - 1], C[j - 1][i - 1]), C[j + n - i - 1][n - i - 1]), C[n][n - i]), sq[j])), cnt++;
改成这样
// Swapped order: j is the outer index, i the inner one (descending). The same
// (i, j) pairs are visited: 1 <= i <= min(n, j), 1 <= j <= m.
// Inside the inner loop C[m - j + n - 1][n - 1] and sq[j] no longer depend on
// the loop variable, and C[j - 1][i - 1] and C[n][n - i] each walk along a
// single row; only C[j + n - i - 1][n - i - 1] still changes row every step.
// NOTE(review): having one row-jumping read per step instead of three is
// presumably why this version runs faster (cache locality) -- confirm by
// profiling.
for (int j = 1; j <= m; j++)
for (int i = min(n, j); i >= 1; i--)
ans = madd(ans, mmul(mmul(mmul(mmul(C[m - j + n - 1][n - 1], C[j - 1][i - 1]), C[j + n - i - 1][n - i - 1]), C[n][n - i]), sq[j])), cnt++;
就可以快一倍。这是为什么?