树剖+线段树 TLE on #4 求助
查看原帖
树剖+线段树 TLE on #4 求助
832472
yishanyi楼主2024/10/15 16:58
#include <bits/stdc++.h>
using namespace std;

// Single-character I/O short-hands: `get` reads one character from stdin, `put` writes one to stdout.
// NOTE(review): since `using namespace std;` is in effect, any later identifier spelled `get`
// (e.g. std::get) would be rewritten to `getchar()` by this macro; a more distinctive name is safer.
#define get getchar()
#define put putchar

/*
 * Reads one decimal integer from `in` (stdin by default).
 * Any non-digit characters before the number are skipped; a '-' among them makes the result negative.
 * The character that terminates the digits is consumed.
 * If the stream ends before a digit is found, returns 0 instead of spinning forever
 * (the previous `char`-based loop never tested for EOF and hung on exhausted input).
 */
inline int read(FILE* in = stdin) {
    int ch = getc(in);  // int, not char: EOF must remain distinguishable from every byte value
    int val = 0, sign = 1;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = getc(in);
    }
    while (ch >= '0' && ch <= '9') {
        val = val * 10 + (ch - '0');
        ch = getc(in);
    }
    return val * sign;
}

/*
 * Writes the decimal representation of n (no separator) to `out` (stdout by default).
 * The magnitude is handled as unsigned so that INT_MIN is printed correctly;
 * the previous `n = -n` overflowed (undefined behaviour) for that value.
 */
inline void print(int n, FILE* out = stdout) {
    unsigned int mag = static_cast<unsigned int>(n);
    if (n < 0) {
        putc('-', out);
        mag = 0u - mag;  // well-defined modular negation, exact even for INT_MIN
    }
    char digits[16];
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + mag % 10u);
        mag /= 10u;
    } while (mag != 0u);
    while (len > 0) putc(digits[--len], out);  // digits were collected least-significant first
}

// Array capacity for vertices (with slack); the segment tree below has 4 * N nodes.
#define N 300010
// Left / right child of segment-tree node u, and the midpoint of the current range [l, r].
// These expand textually, so they only work where `u`, `l` and `r` are in scope.
#define ls (u << 1)
#define rs (u << 1 | 1)
#define mid ((l + r) >> 1)

// n: vertex count; q[1..n]: vertex sequence read in main (consecutive entries are joined by paths);
// ans[v]: value printed for vertex v.
int n, q[N], ans[N];
// Heavy-light decomposition data, filled by dfs(u) and dfs(u, tp):
//   fa = parent, siz = subtree size, dep = depth (root has depth 1), son = heavy child (0 if none),
//   top = head of the heavy chain containing the vertex, num = position in HLD order, tot = counter for num.
int fa[N], siz[N], dep[N], son[N], top[N], tot, num[N];

// Adjacency lists of the undirected tree.
vector<int> e[N];

// Segment tree over HLD positions 1..n with a lazy "add" tag per node.
// NOTE: add(u, x) increases `sum` by x regardless of the node's segment length, so `sum` is only
// exact at leaves; that suffices here because qry() reads leaves only.
struct {
   int sum, add;
} t[N << 2];

// Pull-up: recompute node u's sum from its two children.
inline void up(int u) {
    const int left = u << 1;
    const int right = left | 1;
    t[u].sum = t[left].sum + t[right].sum;
}

// Applies a pending "+x" to node u: records it in the lazy tag and adds x (once) to the node's sum.
inline void add(int u, int x) {
    t[u].add += x;
    t[u].sum += x;
}

// Push-down: hands node u's lazy tag to both children and clears it.
inline void down(int u) {
    const int tag = t[u].add;
    t[u].add = 0;
    add(u << 1, tag);
    add(u << 1 | 1, tag);
}

// Range update: adds x to every position in [ql, qr] within node u's range [l, r].
inline void add(int u, int l, int r, int ql, int qr, int x) {
    const bool disjoint = qr < l || r < ql;
    if (disjoint) return;
    const bool covered = ql <= l && r <= qr;
    if (covered) {
        add(u, x);  // whole node inside the query: tag it and stop
        return;
    }
    down(u);
    const int half = (l + r) >> 1;
    add(u << 1, l, half, ql, qr, x);
    add(u << 1 | 1, half + 1, r, ql, qr, x);
    up(u);
}

// Point query: returns the value at position k. Walks from node u (covering [l, r]) down to the
// leaf for k, pushing lazy tags on the way.
inline int qry(int u, int l, int r, int k) {
    while (l != r) {
        down(u);
        const int half = (l + r) >> 1;
        if (k > half) {
            u = u << 1 | 1;
            l = half + 1;
        } else {
            u = u << 1;
            r = half;
        }
    }
    return t[u].sum;
}

/*
 * First HLD pass over the subtree of u (fa[u] must already be set; 0 for the root).
 * Fills dep, fa, siz and son (heavy child = first child in adjacency order with the strictly
 * largest subtree). Iterative: the recursive version nested once per tree level, and a path-shaped
 * tree of ~3e5 vertices can exhaust a default-sized call stack.
 */
inline void dfs(int u) {
    // Parents before children (BFS order); fa and dep are assigned on discovery.
    std::vector<int> order;
    order.push_back(u);
    dep[u] = dep[fa[u]] + 1;
    for (std::size_t i = 0; i < order.size(); ++i) {
        const int x = order[i];
        for (int v : e[x]) {
            if (v == fa[x]) continue;
            fa[v] = x;
            dep[v] = dep[x] + 1;
            order.push_back(v);
        }
    }
    // Children before parents: every child's size is final when its parent is processed,
    // so scanning children in adjacency order picks the same heavy child as the recursion did.
    for (std::size_t i = order.size(); i-- > 0;) {
        const int x = order[i];
        siz[x] = 1;
        for (int v : e[x]) {
            if (v == fa[x]) continue;
            siz[x] += siz[v];
            if (siz[v] > siz[son[x]]) son[x] = v;
        }
    }
}

/*
 * Second HLD pass: assigns top (chain head) and num (HLD position, via ++tot) to the subtree of u,
 * whose chain head is tp. The heavy child is numbered right after its parent, so every heavy chain
 * occupies a contiguous range of positions.
 * Iterative with an explicit stack: light children are pushed in reverse adjacency order and the
 * heavy child last, which reproduces the recursive pre-order exactly without deep recursion.
 */
inline void dfs(int u, int tp) {
    std::vector<std::pair<int, int>> stk;  // (vertex, head of its chain)
    stk.emplace_back(u, tp);
    while (!stk.empty()) {
        const int x = stk.back().first;
        const int head = stk.back().second;
        stk.pop_back();
        top[x] = head;
        num[x] = ++tot;
        for (auto it = e[x].rbegin(); it != e[x].rend(); ++it) {
            const int v = *it;
            if (v == fa[x] || v == son[x]) continue;
            stk.emplace_back(v, v);  // a light child starts a new chain
        }
        if (son[x]) stk.emplace_back(son[x], head);  // popped next: continues the current chain
    }
}

// Adds x to every vertex on the tree path u -- v (both ends included) by splitting the path into
// heavy-chain pieces, each a contiguous range of HLD positions.
inline void cgt(int u, int v, int x) {
    for (; top[u] != top[v]; u = fa[top[u]]) {
        // Make u the endpoint whose chain head is not shallower, cover head..u, then jump above the head.
        if (dep[top[v]] > dep[top[u]]) swap(u, v);
        add(1, 1, n, num[top[u]], num[u], x);
    }
    // Same chain now: the shallower endpoint has the smaller position.
    if (dep[v] < dep[u]) swap(u, v);
    add(1, 1, n, num[u], num[v], x);
}

// Reads n, the vertex sequence q[1..n] and n - 1 edges. Adds 1 to every vertex on each path
// q[i] -- q[i+1], then prints the resulting value of vertices 1..n, one per line.
signed main() {
    n = read();
    for (int i = 1; i <= n; ++i) q[i] = read();
    for (int edges = n - 1; edges > 0; --edges) {
        const int a = read();
        const int b = read();
        e[a].push_back(b);
        e[b].push_back(a);
    }
    dfs(1);     // parents, depths, subtree sizes, heavy children (root = 1)
    dfs(1, 1);  // chain heads and HLD positions
    for (int i = 1; i < n; ++i) cgt(q[i], q[i + 1], 1);
    // The first vertex keeps its full path-cover count; every later vertex in the sequence
    // (each the endpoint of the preceding leg) gets one subtracted.
    ans[q[1]] = qry(1, 1, n, num[q[1]]);
    for (int i = 2; i <= n; ++i) {
        const int covered = qry(1, 1, n, num[q[i]]);
        ans[q[i]] = covered - 1;
    }
    for (int v = 1; v <= n; ++v) {
        print(ans[v]);
        put('\n');
    }
    return 0;
}

我已经很努力了

2024/10/15 16:58
加载中...