求助,卡了一天了,现在一直80分,学的题解里第四个大佬的写法。
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
const int M = 800005;
#define ll long long
int read()
{
int x=0,f=1;char c;
while((c=getchar())<'0' || c>'9') {if(c=='-') f=-1;}
while(c>='0' && c<='9') {x=(x<<3)+(x<<1)+(c^48);c=getchar();}
return x*f;
}
int n,num,cnt,mi,eg,sz,siz[M],del[2*M],rt[M],lt[M];
int ch[16*M][2];ll dis[M],mx[16*M][2],ans;
struct edge
{
int v,c,next;
};
template<unsigned int M>struct graph
{
int tot,f[M];edge e[2*M];
graph() {tot=1;}//important
void add(int u,int v,int c)
{
e[++tot]=edge{v,c,f[u]},f[u]=tot;
e[++tot]=edge{u,c,f[v]},f[v]=tot;
}
};graph<M> g1,g2;graph<M> g;
void dfs(int u,int fa)//三度化
{
int tmp=u;
for(int i=g1.f[u];i;i=g1.e[i].next)
{
int v=g1.e[i].v,c=g1.e[i].c;
if(v==fa) continue;
g.add(tmp,++num,0);
g.add(num,v,c);
dis[v]=dis[u]+c;
dfs(v,u);
}
}
void dfs2(int u,int fa,ll d,int op)//处理初始形态的边分树
{
if(!op) sz++;
if(u<=n)//不是虚点就维护树
{
int x=++cnt;//新建节点
if(!rt[u]) rt[u]=lt[u]=++cnt;//还没有根
mx[lt[u]][op]=d+dis[u];//维护最值
ch[lt[u]][op]=x;
lt[u]=x;
}
for(int i=g.f[u];i;i=g.e[i].next)
{
int v=g.e[i].v,c=g.e[i].c;
if(v==fa || del[i]) continue;
dfs2(v,u,d+c,op);
}
}
void get(int u,int fa)//找到最好的边
{
siz[u]=1;//???一开始没打这个
for(int i=g.f[u];i;i=g.e[i].next)
{
int v=g.e[i].v;
if(v==fa || del[i]) continue;
get(v,u);
siz[u]+=siz[v];
if(max(siz[v],sz-siz[v])<mi)
mi=max(siz[v],sz-siz[v]),eg=i;
}
}
void solve(int x,int s)//边分治
{
if(s==1) return ;//如果已经没有边了
mi=1e9;eg=0;sz=s;
get(x,0);
del[eg]=del[eg^1]=1;//删去这条边
int u=g.e[eg].v,v=g.e[eg^1].v;//获取这条边的两个端点
sz=0;//顺便算一下子树大小
dfs2(u,v,0,0);dfs2(v,u,g.e[eg].c,1);
int tmp=s-sz;//一定要先存下来
solve(u,sz);solve(v,tmp);
}
int merge(int x,int y,ll d)//边分树合并
{
if(!x || !y) return x+y;
ans=max(ans,max(mx[x][0]+mx[y][1],mx[y][0]+mx[x][1])+d);
mx[x][0]=max(mx[x][0],mx[y][0]);
mx[x][1]=max(mx[x][1],mx[y][1]);
ch[x][0]=merge(ch[x][0],ch[y][0],d);
ch[x][1]=merge(ch[x][1],ch[y][1],d);
return x;
}
void dfs3(int u,int fa,ll d)//访问第二棵树
{
ans=max(ans,2*(dis[u]-d));//x=y的情况
for(int i=g2.f[u];i;i=g2.e[i].next)
{
int v=g2.e[i].v,c=g2.e[i].c;
if(v==fa) continue;
dfs3(v,u,d+c);
rt[u]=merge(rt[u],rt[v],-2*d);
}
}
signed main()
{
n=num=read();
memset(mx,0xc0,sizeof mx);//important
for(int i=1;i<n;i++)
{
int u=read(),v=read(),c=read();
g1.add(u,v,c);
}
for(int i=1;i<n;i++)
{
int u=read(),v=read(),c=read();
g2.add(u,v,c);
}
dfs(1,0);
solve(1,num);
ans=-1e18;
dfs3(1,0,0);
printf("%lld\n",ans/2);
}