88pts,今日内悬关
查看原帖
88pts,今日内悬关
481337
Pratty楼主2024/10/24 20:31
#include <bits/stdc++.h>
#define int long long
using namespace std;
constexpr int mod = 998244353;  // every count is reported modulo this prime

// Dimensions read from input: n rows, m columns.
int n;
int m;
// ans: scratch "over-represented" count for one column inside main's loop;
// all: running answer (total selections minus the bad ones so far).
int ans;
int all;

// a: input matrix, 1-indexed.
// s: per-row sums of a.
// f: per-column balance DP, second index offset by n.
// g: g[i][k] = ways to choose k cells from rows 1..i, at most one per row.
int a[110][2100];
int s[110];
int f[110][210];
int g[110][2100];
// Entry point. Reads n, m and the n x m matrix a, then prints
//   (number of non-empty selections taking at most one cell per row)
//   - sum over columns l of (selections in which column l supplies strictly
//     more of the chosen cells than all other columns combined),
// all modulo mod.
//
// NOTE(review): this looks like the CSP-S 2019 "Emiya's dinner" problem
// (Luogu P5664: a[i][j] < mod, n <= 100, m <= 2000) - confirm bounds.
signed main() {
	scanf("%lld%lld", &n, &m);
	for (int i = 1; i <= n; i++) {
		for (int j = 1; j <= m; j++) {
			scanf("%lld", &a[i][j]);
			// Only s[i] modulo `mod` is ever needed. The raw sum of up to m
			// entries near mod reaches ~2e12, and multiplying it by another
			// value below mod (as done below) overflows long long - UB and
			// wrong answers on large tests.
			s[i] = (s[i] + a[i][j]) % mod;
		}
	}

	// g[i][j]: ways to pick exactly j cells from rows 1..i, at most one per
	// row; taking a cell from row i contributes a factor of s[i].
	g[0][0] = 1;
	for (int i = 1; i <= n; i++) {
		for (int j = 0; j <= i; j++) {
			g[i][j] = g[i - 1][j];
			if (j) {
				// Both factors are in [0, mod), so the product fits in 64 bits.
				g[i][j] = (g[i][j] + s[i] * g[i - 1][j - 1]) % mod;
			}
		}
	}

	// Every non-empty selection, grouped by its size.
	for (int i = 1; i <= n; i++) {
		all = (all + g[n][i]) % mod;
	}

	// For each column l, subtract the selections where l is over-represented.
	// f[i][j]: ways over rows 1..i with
	//   (#cells taken in column l) - (#cells taken elsewhere) == j - n;
	// the offset n keeps the index non-negative.
	for (int l = 1; l <= m; l++) {
		memset(f, 0, sizeof(f));
		f[0][n] = 1;
		ans = 0;
		for (int i = 1; i <= n; i++) {
			// Loop-invariant per-row weights, both reduced into [0, mod).
			const int inCol = a[i][l] % mod;
			const int outCol = (s[i] - inCol + mod) % mod;
			for (int j = 1; j <= 2 * n; j++) {
				f[i][j] = (f[i - 1][j]
				           + f[i - 1][j - 1] * inCol % mod
				           + f[i - 1][j + 1] * outCol % mod) % mod;
			}
		}
		// Column l supplies strictly more than half: difference >= 1.
		for (int j = 1; j <= n; j++) {
			ans = (ans + f[n][j + n]) % mod;
		}
		all = (all - ans + mod) % mod;
	}
	printf("%lld", (all + mod) % mod);
	return 0;
}
2024/10/24 20:31
加载中...