rt
in
20 1
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
2 20
2 5
2 7
20 13
5 6
5 19
5 4
7 17
7 11
7 3
13 14
13 16
13 8
6 12
6 10
6 9
19 18
4 15
17 1
force
#include<iostream>
#include<vector>
#include<string.h>
#define fir first
#define sec second
#define mp make_pair
using namespace std;
typedef pair<int,int> ___pii;
const int maxn=1001;
int n,__ans,x,a[maxn];
bool e2[1001][1001];
vector<___pii>__tot;
bool __vis[maxn];
void __dfs2(int ___u){
__vis[___u]=true;
for(int i=1;i<=n;i++){
if(i==___u||__vis[i])continue;
if(!e2[i][___u])continue;
__dfs2(i);
}
}
void count(){
for(int i=1;i<=n;i++){
memset(__vis,false,sizeof __vis);
__dfs2(i);int cnt=0;
for(int j=1;j<=n;j++){
if(!__vis[j])continue;
cnt^=a[j];
}
if(cnt!=x)return;
}
__ans++;
}
void __dfs3(int ___u){
if(___u==__tot.size()){
count();
return;
}
__dfs3(___u+1);
___pii y=__tot[___u];
e2[y.fir][y.sec]=e2[y.sec][y.fir]=false;
__dfs3(___u+1);
e2[y.fir][y.sec]=e2[y.sec][y.fir]=true;
}
int main(){
scanf("%lld%lld",&n,&x);
for(int i=1;i<=n;i++){
scanf("%lld",&a[i]);
}
for(int i=1;i<=n-1;i++){
int ___u,___v;
scanf("%lld%lld",&___u,&___v);
e2[___u][___v]=e2[___v][___u]=true;
__tot.push_back(mp(___v,___u));
}
__dfs3(0);
printf("%lld",__ans);
return 0;
}
AC code
#include<iostream>
#include<vector>
#define pb push_back
#define int long long
using namespace std;
int in(){
int k=0,f=1;
char c=getchar();
while(c<'0'||c>'9')
{
if(c=='-')f=-1;
c=getchar();
}
while(c>='0'&&c<='9')k=k*10+c-'0',c=getchar();
return k*f;
}
const int mod=998244353;
const int maxn=10000005;
int n,x,a[maxn];
vector<int>e[maxn];
int f[maxn][2],g[maxn][2];
void dfs(int u,int fa){
int son[maxn],cnt=0;
for(int i=0;i<e[u].size();i++){
int to=e[u][i];
if(to==fa)continue;
son[++cnt]=to;
dfs(to,u);
}
if(cnt==0){
f[u][a[u]]=1;
return;
}
g[1][0]=g[0][0]=1;
for(int i=1;i<=cnt;i++){
g[i][0]=(f[son[i]][x]*g[i-1][0])%mod+(g[i-1][0]*f[son[i]][0])%mod+(g[i-1][1]*f[son[i]][1])%mod;
g[i][1]=(f[son[i]][x]*g[i-1][1])%mod+(g[i-1][0]*f[son[i]][1])%mod+(g[i-1][1]*f[son[i]][0])%mod;
g[i][0]%=mod,g[i][1]%=mod;
}
f[u][1]=g[cnt][!a[u]]%mod;
f[u][0]=g[cnt][a[u]]%mod;
}
signed main(){
n=in(),x=in();
for(int i=1;i<=n;i++)a[i]=in();
for(int i=1;i<=n-1;i++){
int u=in(),v=in();
e[u].pb(v);e[v].pb(u);
}
dfs(1,-1);
printf("%lld",f[1][x]%mod);
return 0;
}
mix code ->TLE
```cpp
#include<iostream>
#include<vector>
#define pb push_back
#define int long long
#include<string.h>
#define fir first
#define sec second
#define mp make_pair
using namespace std;
typedef pair<int,int> ___pii;
using namespace std;
int in(){
int k=0,f=1;
char c=getchar();
while(c<'0'||c>'9')
{
if(c=='-')f=-1;
c=getchar();
}
while(c>='0'&&c<='9')k=k*10+c-'0',c=getchar();
return k*f;
}
const int mod=998244353;
const int maxn=10000005;
int n,x,a[maxn];
vector<int>e[maxn];
bool e2[1001][1001];
vector<___pii>__tot;
int __ans;
bool __vis[maxn];
void __dfs2(int ___u){
__vis[___u]=true;
for(int i=1;i<=n;i++){
if(i==___u||__vis[i])continue;
if(!e2[i][___u])continue;
__dfs2(i);
}
}
void count(){
for(int i=1;i<=n;i++){
memset(__vis,false,sizeof __vis);
__dfs2(i);int cnt=0;
for(int j=1;j<=n;j++){
if(!__vis[j])continue;
cnt^=a[j];
}
if(cnt!=x)return;
}
__ans++;
}
void __dfs3(int ___u){
if(___u==__tot.size()){
count();
return;
}
__dfs3(___u+1);
___pii y=__tot[___u];
e2[y.fir][y.sec]=e2[y.sec][y.fir]=false;
__dfs3(___u+1);
e2[y.fir][y.sec]=e2[y.sec][y.fir]=true;
}
int f[maxn][2],g[maxn][2];
void dfs(int u,int fa){
int son[maxn],cnt=0;
for(int i=0;i<e[u].size();i++){
int to=e[u][i];
if(to==fa)continue;
son[++cnt]=to;
dfs(to,u);
}
if(cnt==0){
f[u][a[u]]=1;
return;
}
g[1][0]=g[0][0]=1;
for(int i=1;i<=cnt;i++){
g[i][0]=(f[son[i]][x]*g[i-1][0])%mod+(g[i-1][0]*f[son[i]][0])%mod+(g[i-1][1]*f[son[i]][1])%mod;
g[i][1]=(f[son[i]][x]*g[i-1][1])%mod+(g[i-1][0]*f[son[i]][1])%mod+(g[i-1][1]*f[son[i]][0])%mod;
g[i][0]%=mod,g[i][1]%=mod;
}
f[u][1]=g[cnt][!a[u]]%mod;
f[u][0]=g[cnt][a[u]]%mod;
}
void part1(){
for(int i=1;i<=n-1;i++){
int ___u,___v;
scanf("%lld%lld",&___u,&___v);
e2[___u][___v]=e2[___v][___u]=true;
__tot.push_back(mp(___v,___u));
}
__dfs3(0);
printf("%lld",__ans);
}
void part2(){
for(int i=1;i<=n-1;i++){
int u=in(),v=in();
e[u].pb(v);e[v].pb(u);
}
dfs(1,-1);
printf("%lld",f[1][x]%mod);
}
signed main(){
n=in(),x=in();
for(int i=1;i<=n;i++)a[i]=in();
if(n<=20)part1();
else part2();
return 0;
}
why?