为什么不开O2只有80pts(2TLE),开O2后全RE。。。
/*=========================
Source:
Problem:
Author:OYBDOOO
Date:
Result:
=========================*/
#include<bits/stdc++.h>
#define pb push_back
#define int long long
using namespace std;
const int maxn=1e6+10;
const int inf=3e15;
int ttt=0,tot=0;
int rk[maxn],sz[maxn],dep[maxn],fa[maxn],top[maxn],ind[maxn],mi[maxn],zez[maxn],ppp[maxn];
int g[maxn];
vector<int>v[maxn];
int s[maxn],t=0;
struct EDGE{int v,next;}edge[maxn*2];
void add(int aa,int bb){tot++;edge[tot].v=bb;edge[tot].next=ind[aa];ind[aa]=tot;}
bool cmp(int x,int y){return rk[x]<rk[y];}
void dfs1(int x,int fx)
{
sz[x]=1;fa[x]=fx;dep[x]=dep[fx]+1;
for(int i=ind[x];i!=-1;i=edge[i].next)
{
if(fx==edge[i].v)continue;
dfs1(edge[i].v,x);
if(sz[edge[i].v]>sz[zez[x]])zez[x]=edge[i].v;
sz[x]+=sz[edge[i].v];
}
}
void dfs2(int x,int tp)
{
top[x]=tp;rk[x]=++ttt;
if(zez[x])dfs2(zez[x],tp);
for(int i=ind[x];i!=-1;i=edge[i].next)
{
if(zez[x]==edge[i].v||edge[i].v==fa[x])continue;
dfs2(edge[i].v,edge[i].v);
}
}
int lca(int x,int y)
{
while(top[x]!=top[y])
{
if(dep[top[x]]>dep[top[y]])x=fa[top[x]];
else y=fa[top[y]];
}
if(dep[x]<dep[y])return x;
else return y;
}
void ps(int x)
{
if(x==1)return;
int l=lca(x,s[t]);
if(l==s[t]){s[++t]=x;return;}
while(t&&rk[s[t-1]]>rk[l])v[s[t-1]].pb(s[t]),t--;
if(rk[l]>rk[s[t-1]])v[l].pb(s[t]),s[t]=l;
else v[l].pb(s[t]),t--;
s[++t]=x;
}
int mx[maxn],mn[maxn],mxans,mnans,all;
int dp(int x)
{
int i;
sz[x]=ppp[x];g[x]=0;
if(ppp[x])mx[x]=mn[x]=0;
else mx[x]=-inf,mn[x]=inf;
for(i=0;i<v[x].size();i++)
{
int vv=v[x][i];
dp(vv);
int l=dep[vv]-dep[x];
all+=(g[x]+sz[x]*l)*sz[vv]+g[vv]*sz[x];
sz[x]+=sz[vv];g[x]+=g[vv]+l*sz[vv];
mnans=min(mnans,mn[x]+mn[vv]+l);mxans=max(mxans,mx[x]+mx[vv]+l);
mn[x]=min(mn[x],mn[vv]+l),mx[x]=max(mx[x],mx[vv]+l);
}
}
int n,is[maxn];
signed main()
{
int i,aa,bb,cc,T,m;scanf("%lld",&n);
memset(ind,-1,sizeof(ind));
for(i=1;i<n;i++){scanf("%lld%lld",&aa,&bb);add(aa,bb);add(bb,aa);}
dfs1(1,0);dfs2(1,1);
scanf("%lld",&T);
while(T--){
scanf("%lld",&m);
memset(ppp,0,sizeof(ppp));
for(i=1;i<=n;i++)v[i].clear();
for(i=1;i<=m;i++)scanf("%lld",&is[i]),ppp[is[i]]=1;
sort(is+1,is+m+1,cmp);
s[1]=1;t=1;
if(is[1]==2&&is[2]==5)
int my=-1;
for(i=1;i<=m;i++)
ps(is[i]);
while(t>0)v[s[t-1]].pb(s[t]),t--;
// dp(1);
all=0,mnans=inf,mxans=-inf;
dp(1);
cout<<all<<" "<<mnans<<" "<<mxans<<endl;
}
return 0;
}