关于树上背包
查看原帖
关于树上背包
1062683
lottle1212__楼主2025/1/5 07:32

为什么如下写法 1 会 T,写法 2 就没有问题?

#include <iostream>
#include <algorithm>
#include <string.h>
#include <iomanip>
#include <bitset>
#include <math.h>
#include <string>
#include <vector>
#include <queue>
#include <set>
#include <map>
#define fst first
#define scd second
#define db double
#define ll long long
#define mp make_pair
#define pb push_back
#define eb emplace_back
#define vi vector <int>
#define pii pair <int, int>
#define sz(x) ((int)x.size())
#define ms(f, x) memset(f, x, sizeof(f))
#define L(i, j, k) for (int i=(j); i<=(k); ++i)
#define R(i, j, k) for (int i=(j); i>=(k); --i)
#define ACN(i, H_u) for (int i=H_u; i; i=E[i].nxt)
using namespace std;
template <typename INT> void rd(INT &res) {
	res=0; bool f=false; char ch=getchar();
	while (ch<'0'||ch>'9') f|=ch=='-', ch=getchar();
	while (ch>='0'&&ch<='9') res=(res<<1)+(res<<3)+(ch^48), ch=getchar();
	res=(f?-res:res);
}
template <typename INT, typename...Args>
void rd(INT &x, Args &...y) { rd(x), rd(y...); }
//dfs
const int mod=998244353;
const int maxn=5e3;
const int N=maxn+10;
int n, m, fac[N], inv[N], f[N][N], sz[N][N], g[N], col[N], H[N], tmp[N], edge_cnt;
//wmr
struct Edge { int nxt, to; } E[N<<1];
void add(int u, int v) { E[++edge_cnt]={H[u], v}; H[u]=edge_cnt; }
int moda(int x) { return x>=mod?x-mod:x; }
int mods(int x) { return x<0?x+mod:x;}
int modas(int x) { return moda(mods(x)); }
int quick_power(int x, int y) {
	int res=1;
	while (y) {
		if (y&1) res=(ll)res*x%mod;
		x=(ll)x*x%mod, y>>=1;
	}
	return res;
}
int C(int n, int m) { return (ll)fac[n]*inv[m]%mod*inv[n-m]%mod; }
//incra
void dfs(int u, int pre) {
	++sz[u][col[u]]; f[u][0]=1;
	ACN(i, H[u]) {
		int v=E[i].to;
		if (v==pre) continue;
		dfs(v, u);


  //写法1
		sz[u][0]+=sz[v][0], sz[u][1]+=sz[v][1];
		L(j, 0, sz[u][0]+sz[u][1]) tmp[j]=f[u][j], f[u][j]=0;
		L(j, 0, min(sz[u][0], sz[u][1]))
  		L(k, 0, min(j, min(sz[v][0], sz[v][1])))
    		f[u][j]=moda(f[u][j]+(ll)f[v][k]*tmp[j-k]%mod);

    //写法2
		int su=sz[u][0]+sz[u][1], sv=sz[v][0]+sz[v][1];
		L(j, 0, su+sv) tmp[j]=f[u][j], f[u][j]=0;
		L(j, 0, min(su, m))
			L(k, 0, min(sv, m-j))
				f[u][j+k]=moda(f[u][j+k]+(ll)tmp[j]*f[v][k]%mod);
		sz[u][0]+=sz[v][0], sz[u][1]+=sz[v][1];
	}
	R(i, min(sz[u][0], sz[u][1]), 1) f[u][i]=moda(f[u][i]+(ll)f[u][i-1]*(sz[u][col[u]^1]-i+1)%mod);
}
//lottle
signed main() {
//	ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
	freopen("P6478.in", "r", stdin);
	freopen("P6478.out", "w", stdout);
	rd(n); m=n>>1;
	fac[0]=1; L(i, 1, m) fac[i]=(ll)fac[i-1]*i%mod;
	inv[m]=quick_power(fac[m], mod-2); R(i, m-1, 0) inv[i]=(ll)inv[i+1]*(i+1)%mod;
	L(i, 1, n) scanf("%1d", &col[i]);
	L(i, 1, n-1) { int u, v; rd(u, v); add(u, v); add(v, u); }
	dfs(1, 0);
	L(i, 0, m) g[i]=(ll)f[1][i]*fac[m-i]%mod;
	L(i, 0, m) {
		int res=0;
		L(j, i, m) res=modas(res+(ll)(j-i&1?-1:1)*C(j, i)*g[j]%mod);
		printf("%d\n", res);
	}
	return 0;
}
/*
input
8
10010011
1 2
1 3
2 4
2 5
5 6
3 7
3 8
output
0
10
10
4
0
*/
2025/1/5 07:32
加载中...