什么鬼
查看原帖
什么鬼
384214
esquigybcu楼主2021/10/14 20:22

Code:

#include <stdio.h>
#include <string.h>
#include <limits.h>
#include <algorithm>

typedef long long ll;
typedef unsigned long long ull;

// Integer floor(log2(n)): index of the highest set bit of n.
// Only meaningful for n > 0; __builtin_clz is undefined for 0.
inline constexpr int log2(int n)
{
    return (CHAR_BIT * static_cast<int>(sizeof(unsigned int)) - 1) - __builtin_clz(n);
}

// Capacity of every per-node array (node ids are used 1-based, id 0 is a sentinel).
const int N = 300000 + 5;

// One directed arc of the adjacency list: u -> v, with `next` being the
// index of the following arc that leaves u (-1 terminates the chain).
struct edge
{
    int u;
    int v;
    int next;
};

// Arc pool; add_edge is called twice per input edge, hence 2 * N slots.
edge e[2 * N];
// Number of arcs stored in e so far.
int cnt;
// head[u]: index in e of the most recently added arc leaving u
// (main fills it with -1 before any arc is added).
int head[N];
inline void add_edge(int u, int v)
{
    e[cnt].u = u, e[cnt].v = v, e[cnt].next = head[u], head[u] = cnt++;
}

// n         : number of nodes, read in main.
// a         : the sequence read in main, stored in a[0..n-1].
// depth[u]  : set by lca::dfs to depth[parent] + 1 (main sets depth[0] = -1).
// top[u][k] : 2^k-th ancestor of u; level 0 is written by lca::dfs, higher
//             levels by lca::get. Entries past the root stay 0.
//
// lca::get and lca::lca index levels 0..log2(n), and log2(n) can be as large
// as log2(N), so exactly log2(N) + 1 columns are needed. A bare log2(N) is
// one column short: the write to level log2(N) then lands in the next row's
// level-0 (parent) slot and corrupts the parent table. Doubling the column
// count hides that off-by-one at the cost of almost twice the memory.
int n, a[N], depth[N], top[N][log2(N) + 1];
// delta[u]: per-node difference counters, updated in main and summed over
// subtrees by dfs.
ll delta[N];

// Binary-lifting lowest common ancestor over the global adjacency list.
// Usage in this file: dfs(root, 0) first, then get(), then lca(u, v).
namespace lca
{
    // Records the parent (top[u][0]) and depth of every node reachable from u,
    // treating f as u's parent. main sets depth[0] = -1 and calls dfs(1, 0).
    inline void dfs(int u, int f)
    {
        top[u][0] = f;
        depth[u] = depth[f] + 1;

        for (int idx = head[u]; idx != -1; idx = e[idx].next)
        {
            const int child = e[idx].v;
            if (child == f)
                continue; // do not walk back up the arc we came from
            dfs(child, u);
        }
    }

    // Fills top[i][k] for k = 1..log2(n) from the level below:
    // the 2^k-th ancestor is the 2^(k-1)-th ancestor of the 2^(k-1)-th ancestor.
    inline void get()
    {
        const int max_level = log2(n);
        for (int level = 1; level <= max_level; ++level)
            for (int node = 1; node <= n; ++node)
            {
                const int halfway = top[node][level - 1];
                top[node][level] = top[halfway][level - 1];
            }
    }

    // Returns the lowest common ancestor of u and v.
    inline int lca(int u, int v)
    {
        // Make u the deeper endpoint.
        if (depth[u] < depth[v])
            std::swap(u, v);

        // Lift u by the largest power of two not exceeding the remaining gap
        // until both nodes sit at the same depth.
        for (int gap = depth[u] - depth[v]; gap > 0; gap = depth[u] - depth[v])
            u = top[u][log2(gap)];

        if (u == v)
            return u;

        // Lift both nodes together while their ancestors differ; afterwards
        // they are distinct children of the answer.
        for (int level = log2(depth[u]); level >= 0; --level)
        {
            const int up_u = top[u][level];
            const int up_v = top[v][level];
            if (up_u != up_v)
            {
                u = up_u;
                v = up_v;
            }
        }
        return top[u][0];
    }
}

// Post-order accumulation: after the call, delta[u] holds its own value plus
// the delta of every node in u's subtree. Relies on top[u][0] being u's parent.
inline void dfs(int u)
{
    const int parent = top[u][0];
    for (int idx = head[u]; idx != -1; idx = e[idx].next)
    {
        const int child = e[idx].v;
        if (child == parent)
            continue; // skip the arc leading back to the parent
        dfs(child);
        delta[u] += delta[child];
    }
}

int main()
{
    memset(head, -1, sizeof head);
    depth[0] = -1; 

    scanf("%d", &n);
    for (int i = 0; i < n; i++)
        scanf("%d", &a[i]);
    for (int i = 0; i < n - 1; i++)
    {
        int u, v;
        scanf("%d %d", &u, &v);
        add_edge(u, v), add_edge(v, u);
    }

    lca::dfs(1, 0);
    lca::get();

    for (int i = 0; i < n - 1; i++)
    {
        int u = a[i], v = a[i + 1];
        int l = lca::lca(u, v);
        delta[u]++, delta[v]++, delta[l]--, delta[top[l][0]]--;
    }

    dfs(1);
    for (int i = 1; i < n; i++)
        delta[a[i]]--;
    for (int i = 1; i <= n; i++)
        printf("%lld\n", delta[i]);

    return 0;
}

把 `top[N][2 * log2(N)]` 中的 `2 *` 删掉就 MLE（Memory Limit Exceeded）了……

2021/10/14 20:22
加载中...