以下代码中的第150行(即 dfs 函数末尾标有 `//----------here---------` 的那一行)如果注释掉就会导致 WA;不注释的话,虽然复杂度是假的,但还能拿到 40pts。
第150行是一行调试代码,调用了 print 函数。print 中负责输出的部分已经被注释掉了,所以它现在唯一的作用就是把整棵树上所有的标记下传。但是我已经在 merge 函数中调用了 pushdown,按理说不应该出现这种情况啊?
求调玄关。
(数组可能还开小了,导致RE,这不用管)
#include <algorithm>
#include <cstdio>
#include <iostream>
#define int long long
using namespace std;
// Output staging buffer: bytes are collected here and written to stdout in bulk.
// p3 points at the next free position inside wbuf.
char wbuf[(1 << 21) + 1], *p3 = wbuf;
// Writes the staged bytes [wbuf, p3) to stdout with fwrite, then rewinds p3 to the start of wbuf.
#define flush (fwrite(wbuf, 1, p3 - wbuf, stdout), p3 = wbuf)
// Replaces putchar: if 1 << 21 bytes are already staged the buffer is flushed first,
// then the byte is stored at *p3 and p3 advances.
#define putchar(__x__) (p3 == wbuf + (1 << 21) ? flush : p3, (*p3++) = (__x__))
// From here on `endl` means "append '\n' to the buffer"; it does not flush.
#define endl putchar('\n')
// Appends a single space to the buffer.
#define space putchar(' ')
// Emits the decimal representation of x through putchar, with a leading '-' for
// negative values and a single '0' for zero.
void write(int x) {
    if (x == 0) {
        putchar('0');
        return;
    }
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    // Digits come out least-significant first, so park them and replay them backwards.
    int digits[100];
    int cnt = 0;
    for (; x != 0; x /= 10) {
        digits[++cnt] = x % 10;
    }
    for (; cnt != 0; --cnt) {
        putchar(digits[cnt] + '0');
    }
}
// Emits every character of the NUL-terminated string str through putchar.
void write(const char* str) {
    const char* cur = str;
    while (*cur) {
        putchar(*cur);
        ++cur;
    }
}
// Input buffer: buf holds the chunk most recently read from stdin.
// p1 is the read cursor and p2 marks the end of the valid data in buf.
char buf[1 << 21], *p1, *p2;
// Replaces getchar: when the chunk is used up (p1 == p2) buf is refilled with fread;
// the macro yields EOF if that refill returns no bytes, otherwise the next buffered byte.
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
// Reads the next decimal integer from the input via getchar.
//
// Every non-digit character before the number is skipped; if any of the skipped
// characters is '-', the result is negated. Reading stops at the first non-digit
// after the number (that character is consumed).
//
// Returns 0 if the input ends before any digit is found. The previous version
// looped forever in that situation because EOF is itself "not a digit".
int read() {
    int value = 0, sign = 1;
    // Keep the character in an int so that EOF stays distinguishable no matter
    // whether plain char is signed or unsigned on the target.
    int c = getchar();
    while (c < '0' || c > '9') {
        if (c == EOF) return 0;
        if (c == '-') sign = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        value = value * 10 + (c - '0');
        c = getchar();
    }
    return value * sign;
}
// N:   capacity of the per-vertex arrays (ch, a, b, id, res); also the base for the size of the node array e.
// M:   capacity of the recycled-node stack `pool`.
// mod: modulus used for all modular arithmetic in this file.
constexpr int N = 3e5 + 10, M = 2e7 + 10, mod = 998244353;
// Binary exponentiation: returns base^power reduced modulo `mod`.
// `power` is consumed bit by bit from the least significant end.
int quick_power(int base, int power) {
    int acc = 1;
    for (; power; power >>= 1) {
        if (power & 1) {
            acc = acc * base % mod;
        }
        base = base * base % mod;
    }
    return acc;
}
// n        number of vertices (first value read in main)
// m        number of leaves; also the right end of the segment-tree value range [1, m]
// ch[u]    up to two children of vertex u (0 = no child)
// a[u]     internal vertex: its input value times inv1e4 (mod `mod`);
//          leaf: first the raw input value, later replaced by its 1-based rank in sorted b
// b[1..m]  leaf values, sorted ascending in main
// id       not referenced anywhere in the visible code
// inv1e4   quick_power(10000, mod - 2), presumably the modular inverse of 10000
// idx      highest segment-tree node index handed out so far
// pool     stack of recycled node indices, pooltop is its height
// res[i]   value stored at segment-tree leaf i after expand()
// ans      final answer accumulated in main
int n, m, ch[N][2], a[N], b[N], id[N], inv1e4, idx, pool[M], pooltop, res[N], ans;
// Node storage for the dynamically allocated segment trees. Index 0 is the
// "no node" sentinel and is never handed out by getnode().
//
// Sizing: nodes are only created by build(), which allocates one node per level
// of a single root-to-leaf path. The value range is [1, m] with m <= N < 2^19,
// so a path has at most 20 nodes, and there is one build() per leaf. N * 20
// therefore bounds the total number of nodes; the previous N << 2 was too small
// and led to out-of-bounds accesses on large inputs.
struct {
    int ls, rs;  // indices of the left / right child, 0 when absent
    int sum;     // sum of the values stored in this subtree, modulo `mod`
    int tag;     // pending multiplier not yet applied to the children (1 = nothing pending)
} e[N * 20];
// Hands out a blank segment-tree node: a recycled index from `pool` if one is
// available, otherwise the next unused index. The node starts with no children,
// sum 0 and the neutral multiplier tag 1.
int getnode() {
    int u;
    if (pooltop) {
        u = pool[pooltop--];
    } else {
        u = ++idx;
    }
    e[u].ls = 0;
    e[u].rs = 0;
    e[u].sum = 0;
    e[u].tag = 1;
    return u;
}
// Returns node index u to the recycle stack so that getnode() can reuse it.
void pushnode(int u) {
    pooltop += 1;
    pool[pooltop] = u;
}
// Recomputes the sum of node u from its two children (a missing child is node 0).
void pushup(int u) {
    const int left = e[u].ls;
    const int right = e[u].rs;
    e[u].sum = (e[left].sum + e[right].sum) % mod;
}
// Applies the pending multiplier of node u to each existing child (both the
// child's sum and the child's own pending multiplier), then resets u's tag to 1.
// Does nothing for the null node or when no multiplier is pending.
void pushdown(int u) {
    if (u == 0) return;
    const int mul = e[u].tag;
    if (mul == 1) return;

    const int kids[2] = {e[u].ls, e[u].rs};
    for (int k = 0; k < 2; ++k) {
        const int c = kids[k];
        if (c == 0) continue;
        e[c].sum = e[c].sum * mul % mod;
        e[c].tag = e[c].tag * mul % mod;
    }
    e[u].tag = 1;
}
// Creates a single root-to-leaf path over the range [l, r] whose leaf at index
// `pos` holds `val`, and returns the index of the topmost node of that path.
int build(int pos, int val, int l, int r) {
    const int u = getnode();
    if (l == r) {
        e[u].sum = val;
        return u;
    }
    const int mid = (l + r) >> 1;
    if (pos > mid) {
        e[u].rs = build(pos, val, mid + 1, r);
    } else {
        e[u].ls = build(pos, val, l, mid);
    }
    pushup(u);
    return u;
}
int merge(int u, int v, int sul, int sur, int svl, int svr, int pr, int l, int r) {
// write("merge: "), write(u), space, write(v), space, write(l), space, write(r), endl;
if (!u && !v) return 0;
pushdown(u), pushdown(v);
if (!v) {
int delta = (pr * svl % mod + (mod + 1 - pr) * svr % mod) % mod;
e[u].sum = e[u].sum * delta, e[u].tag = e[u].tag * delta;
// write("merge only u at "), write(u), space, write(delta), endl;
return u;
}
if (!u) {
int delta = (pr * sul % mod + (mod + 1 - pr) * sur % mod) % mod;
e[v].sum = e[v].sum * delta, e[v].tag = e[v].tag * delta;
// write("merge only v at "), write(v), space, write(delta), endl;
return v;
}
int mid = (l + r) >> 1;
int csul = e[e[u].ls].sum, csur = e[e[u].rs].sum, csvl = e[e[v].ls].sum, csvr = e[e[v].rs].sum;
e[u].ls = merge(e[u].ls, e[v].ls, sul, (sur + csur) % mod, svl, (svr + csvr) % mod, pr, l, mid);
e[u].rs = merge(e[u].rs, e[v].rs, (sul + csul) % mod, sur, (svl + csvl) % mod, svr, pr, mid + 1, r);
return pushup(u), u;
}
void expand(int u, int l, int r) {
// write("expand: "), write(u), space, write(l), space, write(r), space, write(e[u].sum), endl, flush;
if (!u) return;
if (l == r) {
res[l] = e[u].sum;
return;
}
int mid = (l + r) >> 1;
pushdown(u);
expand(e[u].ls, l, mid);
expand(e[u].rs, mid + 1, r);
}
void print(int u, int l, int r, int dep) {
if (!u) return;
// for (int i = 0; i < dep; i++) putchar('-');
// write(u), write(": ["), write(l), write(", "), write(r), write("], sum="), write(e[u].sum), endl;
if (l == r) return;
int mid = (l + r) >> 1;
pushdown(u);
print(e[u].ls, l, mid, dep + 1);
print(e[u].rs, mid + 1, r, dep + 1);
}
int dfs(int u) {
if (!ch[u][0]) {
int rt = build(a[u], 1, 1, m);
// write("1-dfs at "), write(u), space, write(a[u]), write(":\n");
// print(rt, 1, m, 0);
return rt;
}
if (!ch[u][1]) {
int rt = dfs(ch[u][0]);
// write("2-dfs at "), write(u), write(":\n");
// print(rt, 1, m, 0);
return rt;
}
int ru = dfs(ch[u][0]), rv = dfs(ch[u][1]);
int rt = merge(ru, rv, 0, 0, 0, 0, a[u], 1, m);
// write("3-dfs at "), write(u), write(":\n");
print(rt, 1, m, 0); //----------here---------
// pushdown(rt);
return rt;
}
signed main() {
inv1e4 = quick_power(10000, mod - 2);
n = read();
read();
for (int i = 2; i <= n; i++) {
int fa = read();
ch[fa][ch[fa][0] != 0] = i;
}
for (int i = 1; i <= n; i++) {
int val = read();
if (ch[i][0]) a[i] = val * inv1e4 % mod;
else b[++m] = a[i] = val;
}
sort(b + 1, b + m + 1);
for (int i = 1; i <= n; i++)
if (!ch[i][0]) a[i] = lower_bound(b + 1, b + m + 1, a[i]) - b;
// write("a: ");
// for (int i = 1; i <= n; i++) write(a[i]), space;
// endl;
int rt = dfs(1);
expand(rt, 1, m);
// write("res: ");
// for (int i = 1; i <= m; i++) write(res[i]), space; endl;
for (int i = 1; i <= m; i++) ans = (ans + i * b[i] % mod * res[i] % mod * res[i] % mod) % mod;
write(ans), endl;
return flush, 0;
}