求助次小生成树
查看原帖
求助次小生成树
327139
纯白（楼主） 2021/8/8 19:45

WA #4、#10
不开 O2 的话 #9、#10 还会 TLE
这题应该没有重边吧？
代码

#include <iostream>
#include <algorithm>
#define int long long
#define ls (o<<1)
#define rs (o<<1|1)
#define mid (s+((t-s)>>1))

using namespace std;

// Capacity of every array below (vertices, edges and segment-tree nodes alike).
// NOTE(review): with `int` macro-expanded to long long these arrays add up to several
// hundred MB of static storage; shrink to the problem's real limits if memory is tight.
const int maxn=5e6;

// Input edge. Ordered by weight so kru() can sort the edge list for Kruskal.
struct node1 {
    int from, to, val;
    // Take the operand by reference: sort() calls this for every comparison.
    bool operator < (const node1 &b) const{
        return val < b.val;
    }
}edg[maxn];
// MST adjacency-list entry stored in an array: target vertex, weight,
// index of the next entry of the same vertex (0 ends the list).
struct node2{
    int to, val, nex;
}e[maxn];
// Segment-tree payload: f = maximum, s = strict second maximum (0 when there is none).
struct heapnode{
    int f,s;
}sumv[maxn];

// n, m: vertex / edge counts; u, v, vale: scratch endpoints and weight of the current edge.
int n,m,u,v,vale;
// cnts: entries used in e[] (and, while reading input, in edg[]); hd: head of each
// adjacency list; cnt: running DFS-order counter.
int cnts,hd[maxn],cnt;
int bs[maxn],fa[maxn],ff[maxn],de[maxn],tp[maxn],sz[maxn],dfn[maxn],num[maxn],val[maxn],ranks[maxn];

// bs: heavy child; fa: parent (0 for the root); ff: disjoint-set parent pointer;
// ranks: disjoint-set merge counter; de: depth; tp: top of the heavy chain;
// sz: subtree size; dfn: position in DFS order; num: vertex at a DFS position;
// val: weight attached to a vertex for the segment tree.

// ans: MST weight; dea: smallest weight increase found for the strict second-best tree.
long long ans,dea;
int root;
// vis[i]: edge i (in sorted order) belongs to the MST.
bool vis[maxn];

// Append the directed edge from -> to (weight val) to the front of from's adjacency list.
inline void adde(int from, int to, int val)
{
    const int slot = ++cnts;
    e[slot].to = to;
    e[slot].val = val;
    e[slot].nex = hd[from];  // link to the previous head of the list
    hd[from] = slot;
}

// Disjoint-set lookup: returns the representative of x and points every vertex
// on the walked path directly at it (path compression).
inline int find(int x)
{
    int rep = x;
    while (ff[rep] != rep)
        rep = ff[rep];
    while (ff[x] != rep)
    {
        const int next = ff[x];
        ff[x] = rep;
        x = next;
    }
    return rep;
}
// Join the sets containing x and y. The representative with the strictly larger
// ranks[] value stays on top (ties keep y's); the surviving representative's counter grows.
inline void merge(int x, int y)
{
    const int rx = find(x);
    const int ry = find(y);
    const bool keep_x = ranks[rx] > ranks[ry];
    const int parent = keep_x ? rx : ry;
    const int child = keep_x ? ry : rx;
    ff[child] = parent;
    ++ranks[parent];
}
// Kruskal: sorts edg[1..m] by weight, links tree edges into the adjacency list,
// marks them in vis[] (indexed by sorted position) and adds their weights to ans.
void kru()
{
    for (int i=1; i<=n; i++) ff[i] = i;
    sort(edg+1, edg+1+m);
    // Components still to be joined; the tree is complete once only one is left.
    int ss=n;
    for (int i=1; i<=m&&ss>1; i++)
    {
        u=edg[i].from,v=edg[i].to,vale=edg[i].val;
        if (find(u) != find(v))
        {
            merge(u,v);
            --ss;  // one fewer component, so the loop can stop after n-1 tree edges
            vis[i] = 1;
            // Tag one endpoint with this weight. The choice does not depend on where the
            // tree is rooted later, so it is not in general the weight of that vertex's
            // parent edge; init() roots the tree at a vertex left untagged.
            if (val[u])
                val[v] = vale;
            else 
                val[u] = vale;
            adde(u, v ,vale);
            adde(v, u, vale);
            ans += vale;
        }
    }
}

inline void dfs1(int u, int en)
{
    fa[u] = en;
    sz[u] = 1;
    de[u] = de[en]+1;
    for (int i=hd[u]; i; i=e[i].nex)
    {
        if (e[i].to == en)
            continue;
        dfs1(e[i].to, u);
        sz[u] += sz[e[i].to];
        if (sz[e[i].to] > sz[bs[u]])
            bs[u] = e[i].to;
    }
}

// Second heavy-light pass: numbers vertices in DFS order (dfn / num), visiting the
// heavy child first so each heavy chain occupies a contiguous dfn range, and
// records in tp[] the top vertex (en) of u's chain.
inline void dfs2(int u, int en)
{
    dfn[u] = ++cnt;
    num[cnt] = u;
    tp[u] = en;
    const int heavy = bs[u];
    if (heavy != 0)
        dfs2(heavy, en);  // heavy child continues the current chain
    for (int i = hd[u]; i != 0; i = e[i].nex)
    {
        const int w = e[i].to;
        if (w != fa[u] && w != heavy)
            dfs2(w, w);  // light child starts a chain of its own
    }
}
// Strict second maximum of four values: the largest value strictly below their
// maximum, or 0 when all four are equal.
// Linear scan instead of sorting: this runs on every push_up and every query.
// NOTE(review): 0 doubles as "no such value", so a genuine 0 is indistinguishable.
inline int getm(int x, int y, int z, int w)
{
    const int top = std::max(std::max(x, y), std::max(z, w));
    const int vals[4] = {x, y, z, w};
    int best = 0;
    bool found = false;
    for (int v : vals)
    {
        if (v < top && (!found || v > best))
        {
            best = v;
            found = true;
        }
    }
    return best;
}
inline void push_up(int o)
{
    sumv[o].f = max(sumv[ls].f, sumv[rs].f);
    sumv[o].s = getm(sumv[ls].f, sumv[rs].f,sumv[ls].s, sumv[rs].s);
}

inline void build(int o, int s, int t)
{
    if (s==t) {
        sumv[o].f = val[num[s]];
        sumv[o].s = 0;
        return ;
    }
    build(ls, s, mid);
    build(rs, mid+1, t);
    push_up(o);
}

// Return {max, strict second max} over positions [l, r]; node o covers [s, t].
// A half that is not visited contributes {0, 0}.
inline heapnode querymax(int o, int l, int r, int s, int t)
{
    if (s >= l && t <= r)
        return sumv[o];
    heapnode lhs{0, 0};
    heapnode rhs{0, 0};
    if (l <= mid)
        lhs = querymax(ls, l, r, s, mid);
    if (mid < r)
        rhs = querymax(rs, l, r, mid+1, t);
    return heapnode{max(lhs.f, rhs.f), getm(lhs.f, lhs.s, rhs.f, rhs.s)};
}

heapnode s{0, 0};
heapnode x{0, 0};
// Return {max, strict second max} of the edge weights on the tree path a..b.
// Each edge weight lives on its child vertex (val[]), so the path covers every
// vertex from a and b up to, but NOT including, their lowest common ancestor.
inline heapnode lca(int a, int b)
{
    heapnode acc{0, 0};
    while (tp[a] != tp[b])
    {
        if (de[tp[a]] < de[tp[b]])
            swap(a, b);
        // The whole chain segment tp[a]..a is on the path, including tp[a]'s parent edge.
        const heapnode seg = querymax(1, dfn[tp[a]], dfn[a], 1, n);
        acc.s = getm(acc.s, acc.f, seg.s, seg.f);
        acc.f = max(acc.f, seg.f);
        a = fa[tp[a]];
    }
    if (de[a] > de[b])
        swap(a, b);
    // a is now the LCA. val[a] is the edge above a, which is not on the path,
    // so start one past dfn[a]; when a == b nothing remains to add.
    if (a != b)
    {
        const heapnode seg = querymax(1, dfn[a]+1, dfn[b], 1, n);
        acc.s = getm(acc.s, acc.f, seg.s, seg.f);
        acc.f = max(acc.f, seg.f);
    }
    return acc;
}
// Build the MST, pick a root, run both heavy-light passes and build the segment
// tree over DFS order.
void init()
{
    kru();
    // Root at the first vertex whose val[] is still 0 (left unchanged if none is).
    int cand = 1;
    while (cand <= n && val[cand] != 0)
        ++cand;
    if (cand <= n)
        root = cand;
    dfs1(root, 0);
    dfs2(root, root);
    build(1, 1, n);
}
signed main ()
{
    ios::sync_with_stdio(false);
    cin.tie(nullptr);  // don't flush cout before every read
    cin >> n >> m;
    for (int i=1; i<=m; i++)
    {
        cin >> u >> v >> vale;
        // A self-loop can never be part of a spanning tree; drop it.
        if (u!=v)
            edg[++cnts] = node1{u, v, vale};
    }
    // m now counts only the kept edges; cnts is reset so adde() can reuse it.
    m=cnts,cnts=0;
    init();
    // Smallest increase over the MST weight achievable by swapping in one non-tree edge.
    dea=0x7f7f7f7f7f;
    for (int i=1; i<=m; i++)
    {
        if (vis[i]) continue;
        u=edg[i].from,v=edg[i].to,vale=edg[i].val;
        x=lca(u,v);
        // The new edge replaces the heaviest path edge, unless that weight equals vale
        // (no strict increase), in which case the strict second maximum is replaced.
        // NOTE(review): 0 doubles as "no strict second maximum"; a real 0-weight path
        // edge is indistinguishable from it - confirm the weight range.
        if (x.f==vale && x.s!=0)
            dea=min(dea, vale-x.s);
        else if (x.f!=vale)
            dea=min(dea, vale-x.f);
    }
    cout << ans + dea;
    return 0; 
}
2021/8/8 19:45
加载中...