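// Note (originally: "rt,悬3关,题目差不多是求树上有多少路径异或和大于等于t"): the problem is roughly to count how many paths in the tree have an XOR sum greater than or equal to t.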
#include <bits/stdc++.h>
#define pb push_back
#define PII pair<int, int>
#define fi first
#define se second
//#define int long long
#define lowbit(x) ((x) & (-x))
#define ll long long
#define ull unsigned long long
using namespace std;
const int N = 3e5 + 5;
const long long mod = 998244353;

// Binary (fast) modular exponentiation: returns a^b mod `mod`, in [0, mod).
//   a: base; any value, reduced modulo `mod` first (negative bases are
//      normalized) so the squaring step below can never overflow 64 bits.
//   b: exponent; must be non-negative (a negative exponent would make the
//      halving loop meaningless).
long long qpow(long long a, long long b) {
    assert(b >= 0);
    a %= mod;
    if (a < 0)
        a += mod;
    long long ans = 1;
    while (b > 0) {
        // both factors are < mod < 2^30, so each product fits in long long
        if (b & 1)
            ans = ans * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return ans;
}
// n: number of vertices (1-indexed).
// dp[v]: subtree size computed by getroot, rooted at the vertex getroot was started from.
// siz: size of the component currently searched for a centroid.
// root / maxx: best centroid candidate found so far and its largest remaining part.
// si[v]: subtree size computed by getdis, with the current centroid as the parent.
// pp[i]: number of paths (including the single-vertex path {i}) that have i as
//        an endpoint and whose XOR is >= T.
// c[i]: bit i of T (only the low 32 bits are filled in).
int n, dp[N], siz, root, maxx, si[N], pp[N], c[33];
// Index of the last allocated trie node. Node 1 is the root. Node 0 acts as an
// empty sentinel: tr[*][0] and lzm[0] are never written as long as erase() is
// only called on keys that were inserted.
ll tt = 1;
// vis[v]: v has already been processed as a centroid, i.e. removed from the tree.
bool vis[N];
// w: vertex weights; T: the threshold read from input.
ll w[N], T;
// Adjacency lists of the tree.
vector<int> G[N];
// Binary trie over the low 32 bits of keys.
// tr[b][p] is the child of node p along bit value b.
// lzm[p] counts how many currently stored keys pass through node p.
// Nodes are never freed (erase only decrements lzm), so all centroid steps
// reuse one trie.
// NOTE(review): these arrays are 3 * N * 130 ints, about 468 MB. Whether N * 130
// nodes always suffices is not checked anywhere. Confirm against the memory limit.
int tr[2][N * 130], lzm[N * 130];
void getroot(int t, int fa) {
dp[t] = 1;
int ma = 0;
for (auto to : G[t]) {
if (to == fa || vis[to])continue;
getroot(to, t);
dp[t] += dp[to];
ma = max(ma, dp[to]);
}
int tot = max(ma, siz - dp[t]);
if (tot < maxx)maxx = tot, root = t;
}
// Scratch buffer filled by getdis for one subtree of the current centroid.
// Each entry is (vertex, XOR of weights along the path from the centroid to that
// vertex, both endpoints included, as seeded by solve()).
vector<pair<int, ll> > op;
void getdis(int t, ll tot, int fa) {
si[t] = 1;
op.pb({ t, tot });
for (auto to : G[t]) {
if (to == fa || vis[to])
continue;
getdis(to, tot ^ w[to], t);
si[t] += si[to];
}
}
void insert(ll x) {
    // Adds the low 32 bits of x to the trie, walking from the most significant bit.
    // Missing nodes are allocated on demand.
    // Every node on the path (excluding the root, node 1) gets its counter bumped.
    ll node = 1;
    for (int bit = 31; bit >= 0; --bit) {
        const int dir = (x >> bit) & 1;
        int &child = tr[dir][node];
        if (child == 0)
            child = ++tt;
        node = child;
        ++lzm[node];
    }
}
void erase(ll x) {
    // Undoes one insert(x): decrements the counters along x's 32-bit path.
    // Nodes are not freed, so x must have been inserted before.
    ll node = 1;
    for (int bit = 31; bit >= 0; --bit) {
        node = tr[(x >> bit) & 1][node];
        --lzm[node];
    }
}
int query(ll x) {
    // Counts stored keys y with (y ^ x) >= T, comparing the low 32 bits
    // against the bits of T in c[].
    // At each bit we keep following the branch where the XOR equals T's bit.
    // When T's bit is 0, the opposite branch makes the XOR strictly larger
    // there, so its whole count is added.
    // Whatever remains at the end matches T exactly (the "equal" case).
    ll node = 1, cnt = 0;
    for (int bit = 31; bit >= 0; --bit) {
        const int xb = (x >> bit) & 1;
        if (c[bit] == 0)
            cnt += lzm[tr[xb ^ 1][node]];
        node = tr[xb ^ c[bit]][node];
    }
    return cnt + lzm[node];
}
// Processes centroid t: marks it removed, then credits every path through t
// (inside t's current component) to pp[] of both of its endpoints.
// Pass 1 walks the neighbors in order. Each vertex x is matched against the
// vertices of earlier subtrees, and the path t..x itself is credited too.
// Pass 2 walks the neighbors in reverse order. Each x is then matched against
// later subtrees, so every cross-subtree path ends up counted for both endpoints.
// NOTE(review): `x.se >= T` compares full 64-bit values, while the trie only
// looks at the low 32 bits. This is consistent only if weights and T fit in
// 32 bits. Confirm with the constraints.
void solve(int t) {
vis[t] = 1;
if (!G[t].size())
return;
// every key inserted during pass 1, so the trie counters can be reset afterwards
vector<ll> sc;
for (auto to : G[t]) {
if (vis[to])
continue;
op.clear();
// op <- (x, XOR of path t..x with both t and x included) for x in to's subtree
getdis(to, w[to] ^ w[t], t);
for (auto x : op) {
sc.pb(x.se);
// path x..t..y has XOR x.se ^ y ^ w[t] (w[t] is included in both halves)
pp[x.fi] += query(x.se ^ w[t]);
// the path t..x itself
if (x.se >= T)
pp[x.fi]++, pp[t]++;
}
// insert only after querying, so x is never paired with its own subtree
for (auto x : op) insert(x.se);
}
// reset the trie counters (nodes stay allocated and are reused)
for (auto x : sc) erase(x);
// pass 2: same subtrees in reverse order; pairs each x with later subtrees
for (int qq = G[t].size() - 1; qq >= 0; qq--) {
int to = G[t][qq];
if (vis[to])
continue;
op.clear();
getdis(to, w[to] ^ w[t], t);
for (auto x : op) {
pp[x.fi] += query(x.se ^ w[t]);
}
for (auto x : op) insert(x.se);
}
// pass 2 inserted exactly the same multiset of keys as pass 1
for (auto x : sc) erase(x);
}
void getans(int t) {
solve(t);
for (auto to : G[t]) {
if (vis[to])
continue;
siz = maxx = dp[to];
root = to;
getroot(to, 0);
getans(root);
}
}
signed main() {
    // freopen("weight.in", "r", stdin);
    // freopen("weight.out", "w", stdout);
    ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
    // Input: n, then n-1 undirected edges, then the n vertex weights, then T.
    cin >> n;
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        G[u].pb(v), G[v].pb(u);
    }
    // Only read the weights here. pp[] depends on T, which is not known yet.
    for (int i = 1; i <= n; i++) cin >> w[i];
    cin >> T;
    // c[i] = bit i of T, used by query() for the bit-by-bit comparison.
    // NOTE(review): only the low 32 bits of T are recorded. Confirm that T fits.
    for (int i = 0; i < 32; i++) {
        if (T & (1ll << i))
            c[i] = 1;
    }
    // The single-vertex path {i} has XOR w[i].
    for (int i = 1; i <= n; i++) pp[i] = (w[i] >= T ? 1 : 0);
    // Centroid decomposition over the whole tree.
    siz = maxx = n;
    root = 1;
    getroot(1, 0);
    getans(root);
    // NOTE(review): nothing is printed. The answer formula below is commented out.
    // ll ans = 1, oo = qpow(n, mod - 2);
    // for (int i = 1; i <= n; i++) {
    // ans *= (n - pp[i]) * oo % mod;
    // ans %= mod;
    // }
    // cout << (((1ll - ans) % mod + mod) % mod) * qpow(n, n) % mod;
    return 0;
}