自己胡出来的 O(n) 解法(虽然看我代码也知道这东西常数跟个根号一样大),空间复杂度也很小,最大的 DP 数组 46MB 怎么会 MLE?
考虑本题组合意义:找一个红色连通块放一个 a 一个 b,找一个蓝色连通块放一个 c 一个 d。(abcd 可以理解为不同的小球。)
设计 DP:dp[i][r/b][当前连通块放的小球个数(只放一个球不区分ad或cd)][是/否已经选了红颜色的连通块][是/否已经选了蓝颜色的连通块]
是否有效闭合孩子连通块(也就是第三维是 2 且转移时改变后两维)由父亲决定。
思路说完了。
#include <bits/stdc++.h>
using namespace std;
const int mod = 998244353;
#define M (long long)
int n;
char c[2000005];
vector<int> e[2000005];
// 原题 <=> 找一个红色连通块放一个 a 一个 b,找一个蓝色连通块放一个 c 一个 d
int dp[2000005][2][3][2][2];
void dfs(int id, int fa) {
dp[id][0][0][0][0] = dp[id][1][0][0][0] = 1;
dp[id][0][1][0][0] = dp[id][1][1][0][0] = 1;
dp[id][0][2][0][0] = dp[id][1][2][0][0] = 1;
if (c[id] == 'r')
dp[id][1][0][0][0] = dp[id][1][1][0][0] = dp[id][1][2][0][0] = 0;
if (c[id] == 'b')
dp[id][0][0][0][0] = dp[id][0][1][0][0] = dp[id][0][2][0][0] = 0;
for (int i : e[id]) {
if (i == fa)
continue;
dfs(i, id);
long long new_dp[2][3][2][2];
memset(new_dp, 0, sizeof(new_dp));
// 是否闭合孩子连通块由自己决定
for (int C = 0; C <= 1; C++) { // 自己颜色,不闭合儿子连通块
int O = 1 - C;
for (int b = 0; b <= 2; b++) {
// 变颜色而不用闭合的儿子
new_dp[C][b][0][0] += M dp[i][O][0][0][0] * dp[id][C][b][0][0];
new_dp[C][b][0][1] += M dp[i][O][0][0][0] * dp[id][C][b][0][1]
+ M dp[i][O][0][0][1] * dp[id][C][b][0][0];
new_dp[C][b][1][0] += M dp[i][O][0][0][0] * dp[id][C][b][1][0]
+ M dp[i][O][0][1][0] * dp[id][C][b][0][0];
new_dp[C][b][1][1] += M dp[i][O][0][0][0] * dp[id][C][b][1][1]
+ M dp[i][O][0][0][1] * dp[id][C][b][1][0]
+ M dp[i][O][0][1][0] * dp[id][C][b][0][1]
+ M dp[i][O][0][1][1] * dp[id][C][b][0][0];
for (int a = 0; a <= b; a++) { // 不变颜色
new_dp[C][b][0][0] += M dp[i][C][a][0][0] * dp[id][C][b - a][0][0];
new_dp[C][b][0][1] += M dp[i][C][a][0][0] * dp[id][C][b - a][0][1]
+ M dp[i][C][a][0][1] * dp[id][C][b - a][0][0];
new_dp[C][b][1][0] += M dp[i][C][a][0][0] * dp[id][C][b - a][1][0]
+ M dp[i][C][a][1][0] * dp[id][C][b - a][0][0];
new_dp[C][b][1][1] += M dp[i][C][a][0][0] * dp[id][C][b - a][1][1]
+ M dp[i][C][a][0][1] * dp[id][C][b - a][1][0]
+ M dp[i][C][a][1][0] * dp[id][C][b - a][0][1]
+ M dp[i][C][a][1][1] * dp[id][C][b - a][0][0];
}
}
// 1+1=2 产生双倍贡献,这里补一倍
new_dp[C][2][0][0] += M dp[i][C][1][0][0] * dp[id][C][1][0][0];
new_dp[C][2][0][1] += M dp[i][C][1][0][0] * dp[id][C][1][0][1]
+ M dp[i][C][1][0][1] * dp[id][C][1][0][0];
new_dp[C][2][1][0] += M dp[i][C][1][0][0] * dp[id][C][1][1][0]
+ M dp[i][C][1][1][0] * dp[id][C][1][0][0];
new_dp[C][2][1][1] += M dp[i][C][1][0][0] * dp[id][C][1][1][1]
+ M dp[i][C][1][0][1] * dp[id][C][1][1][0]
+ M dp[i][C][1][1][0] * dp[id][C][1][0][1]
+ M dp[i][C][1][1][1] * dp[id][C][1][0][0];
}
for (int b = 0; b <= 2; b++) { // 闭合孩子连通块
new_dp[0][b][0][1] += M dp[i][1][2][0][0] * dp[id][0][b][0][0];
new_dp[0][b][1][1] += M dp[i][1][2][1][0] * dp[id][0][b][0][0];
new_dp[0][b][1][1] += M dp[i][1][2][0][0] * dp[id][0][b][1][0];
new_dp[1][b][1][0] += M dp[i][0][2][0][0] * dp[id][1][b][0][0];
new_dp[1][b][1][1] += M dp[i][0][2][0][1] * dp[id][1][b][0][0];
new_dp[1][b][1][1] += M dp[i][0][2][0][0] * dp[id][1][b][0][1];
}
for (int c = 0; c <= 1; c++)
for (int b = 0; b <= 2; b++)
for (int rb = 0; rb <= 1; rb++)
for (int bb = 0; bb <= 1; bb++)
dp[id][c][b][rb][bb] = new_dp[c][b][rb][bb] % mod;
}
}
signed main() {
cin >> n;
for (int i = 1, u, v; i < n; i++) {
cin >> u >> v;
e[u].push_back(v);
e[v].push_back(u);
}
cin >> c + 1;
dfs(1, 1);
cout << (M dp[1][1][2][1][0] + dp[1][1][0][1][1] + dp[1][0][2][0][1] + dp[1][0][0][1][1]) % mod << endl;
return 0;
}