有关树剖求助!为什么两份类似的代码,在洛谷IDE上测试,一个会爆内存,一个不会
查看原帖
有关树剖求助!为什么两份类似的代码,在洛谷IDE上测试,一个会爆内存,一个不会
400333
qzilr楼主2021/9/2 17:40

这是第一份,会爆内存

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=1e+5;
struct edge{
	int to,nxt;
}e[maxn<<1];
struct node{
	int l,r,lc,rc,sum,tag;
}t[maxn<<1];
int n,m,rt,mod,f[maxn],d[maxn],siz[maxn],son[maxn],top[maxn],id[maxn],rk[maxn],val[maxn],head[maxn];
int tot=0;
inline void add(int x,int y){
    e[++tot].nxt=head[x];
    e[tot].to=y;
    head[x]=tot;
}
inline void dfs1(int pos){
	d[pos]=d[f[pos]]+1,siz[pos]=1;
	for(int v,i=head[pos];i;i=e[i].nxt )
		if((v=e[i].to )!=f[pos]){
			f[v]=pos,dfs1(v),siz[pos]+=siz[v];
			if(siz[v]>siz[son[pos]])
				son[pos]=v;
		}
}
inline void dfs2(int pos,int t){
	top[pos]=t,id[pos]=++tot,rk[tot]=pos;
	if(son[pos])	dfs2(son[pos],t);
	for(int v,i=head[pos];i;i=e[i].nxt )
		if((v=e[i].to )!=f[pos]&&v!=son[pos])
			dfs2(v,v);
}
inline void push_up(int pos){
	t[pos].sum =(t[t[pos].lc ].sum +t[t[pos].rc ].sum )%mod;
}
inline void build(int pos,int l,int r){
	if(l==r){
		t[pos].sum =val[rk[l]],t[pos].l =t[pos].r =l;
		return;
	}
	int mid=l+r>>1;
	t[pos].lc =++tot,t[pos].rc =++tot;
	build(t[pos].lc ,l,mid);
	build(t[pos].rc ,mid+1,r);
	push_up(pos);
}
inline int get_len(int pos){
	return t[pos].r -t[pos].l +1;
}
inline void push_down(int pos){
	if(t[pos].tag ){
		int ls=t[pos].lc ,rs=t[pos].rc ,lz=t[pos].tag ;
		(t[ls].sum +=get_len(ls)*lz)%=mod,
		(t[rs].sum +=get_len(rs)*lz)%=mod;
		(t[ls].tag +=lz)%=mod,(t[rs].tag +=lz)%=mod;
		t[pos].tag =0;
	}
}
inline void update(int pos,int l,int r,int k){
	if(t[pos].l >=l&&t[pos].r <=r){
		(t[pos].sum +=get_len(pos)*k)%=mod,
		(t[pos].tag +=k)%=mod;
		return;
	}
	push_down(pos);
	int mid=t[pos].l +t[pos].r >>1;
	if(l<=mid)	update(t[pos].lc ,l,r,k);
	if(r>mid)	update(t[pos].rc ,l,r,k);
	push_up(pos);
}
inline int query(int pos,int l,int r){
	if(t[pos].l >=l&&t[pos].r <=r)
		return t[pos].sum %mod;
	push_down(pos);
	int s=0,mid=t[pos].l +t[pos].r >>1;
	if(l<=mid)	s+=query(t[pos].lc ,l,r);
	if(r>mid)	s+=query(t[pos].rc ,l,r);
	return s%mod;
}
inline void updates(int x,int y,int k){
	while(top[x]!=top[y]){
		if(d[x]<d[y])	swap(x,y);
		update(rt,id[top[x]],id[x],k);
		x=f[top[x]];
	}
	if(id[x]>id[y])	swap(x,y);
	update(rt,id[x],id[y],k);
}
inline int get_sum(int x,int y){
	int sum=0;
	while(top[x]!=top[y]){
		if(d[x]<d[y])	swap(x,y);
		sum+=query(rt,id[top[x]],id[x]);
		x=f[top[x]];
	}
	if(id[x]>id[y])	swap(x,y);
	sum+=query(rt,id[x],id[y]);
	return sum%mod;
}
signed main(){
	scanf("%lld%lld%lld%lld",&n,&m,&rt,&mod);
	for(int i=1;i<=n;++i)	scanf("%lld",val+i);
    for(int x,y,i=1;i<n;++i){
		scanf("%lld%lld",&x,&y);
		add(x,y),add(y,x);
	}
	tot=0,dfs1(rt),dfs2(rt,rt);
	tot=0,build(rt=tot++,1,n);
	for(int opt,x,y,z,i=1;i<=m;++i){
		scanf("%lld",&opt);
		switch(opt){
			case 1:{
				scanf("%lld%lld%lld",&x,&y,&z);
				updates(x,y,z);
				break;
			}
			case 2:{
				scanf("%lld%lld",&x,&y);
				printf("%lld\n",get_sum(x,y));
				break;
			}
			case 3:{
				scanf("%lld%lld",&x,&z);
				update(rt,id[x],id[x]+siz[x]-1,z);
				break;
			}
			case 4:{
				scanf("%lld",&x);
				printf("%lld\n",query(rt,id[x],id[x]+siz[x]-1));
				break;
			}
		}
	}
	return 0;
}

这是第二份,不会爆内存

#include<iostream>
#include<cstdio>
#define int long long
using namespace std;
const int maxn=1e5+10;
struct edge{
    int next,to;
}e[maxn*2];
struct node{
    int l,r,ls,rs,sum,lazy;
}a[maxn*2];
int n,m,r,rt,mod,v[maxn],head[maxn],cnt,f[maxn],d[maxn],son[maxn],size[maxn],top[maxn],id[maxn],rk[maxn];
void add(int x,int y)
{
    e[++cnt].next=head[x];
    e[cnt].to=y;
    head[x]=cnt;
}
void dfs1(int x)
{
    size[x]=1,d[x]=d[f[x]]+1;
    for(int v,i=head[x];i;i=e[i].next)
        if((v=e[i].to)!=f[x])
        {
            f[v]=x,dfs1(v),size[x]+=size[v];
            if(size[son[x]]<size[v])
                son[x]=v;
        }
}
void dfs2(int x,int tp)
{
    top[x]=tp,id[x]=++cnt,rk[cnt]=x;
    if(son[x])
        dfs2(son[x],tp);
    for(int v,i=head[x];i;i=e[i].next)
        if((v=e[i].to)!=f[x]&&v!=son[x])
            dfs2(v,v);
}
inline void pushup(int x)
{
    a[x].sum=(a[a[x].ls].sum+a[a[x].rs].sum)%mod;
}
void build(int l,int r,int x)
{
    if(l==r)
    {
        a[x].sum=v[rk[l]],a[x].l=a[x].r=l;
        return;
    }
    int mid=l+r>>1;
    a[x].ls=cnt++,a[x].rs=cnt++;
    build(l,mid,a[x].ls),build(mid+1,r,a[x].rs);
    a[x].l=a[a[x].ls].l,a[x].r=a[a[x].rs].r;
    pushup(x);
}
inline int len(int x)
{
    return a[x].r-a[x].l+1;
}
inline void pushdown(int x)
{
    if(a[x].lazy)
    {
        int ls=a[x].ls,rs=a[x].rs,lz=a[x].lazy;
        (a[ls].lazy+=lz)%=mod,(a[rs].lazy+=lz)%=mod;
        (a[ls].sum+=lz*len(ls))%=mod,(a[rs].sum+=lz*len(rs))%=mod;
        a[x].lazy=0;
    }
}
void update(int l,int r,int c,int x)
{
    if(a[x].l>=l&&a[x].r<=r)
    {
        (a[x].lazy+=c)%=mod,(a[x].sum+=len(x)*c)%=mod;
        return;
    }
    pushdown(x);
    int mid=a[x].l+a[x].r>>1;
    if(mid>=l)
        update(l,r,c,a[x].ls);
    if(mid<r)
        update(l,r,c,a[x].rs);
    pushup(x);
}
int query(int l,int r,int x)
{
    if(a[x].l>=l&&a[x].r<=r)
        return a[x].sum;
    pushdown(x);
    int mid=a[x].l+a[x].r>>1,tot=0;
    if(mid>=l)
        tot+=query(l,r,a[x].ls);
    if(mid<r)
        tot+=query(l,r,a[x].rs);
    return tot%mod;
}
inline int sum(int x,int y)
{
    int ret=0;
    while(top[x]!=top[y])
    {
        if(d[top[x]]<d[top[y]])
            swap(x,y);
        (ret+=query(id[top[x]],id[x],rt))%=mod;
        x=f[top[x]];
    }
    if(id[x]>id[y])
        swap(x,y);
    return (ret+query(id[x],id[y],rt))%mod;
}
inline void updates(int x,int y,int c)
{
    while(top[x]!=top[y])
    {
        if(d[top[x]]<d[top[y]])
            swap(x,y);
        update(id[top[x]],id[x],c,rt);
        x=f[top[x]];
    }
    if(id[x]>id[y])
        swap(x,y);
    update(id[x],id[y],c,rt);
}
signed main()
{
    scanf("%lld%lld%lld%lld",&n,&m,&r,&mod);
    for(int i=1;i<=n;i++)
        scanf("%lld",&v[i]);
    for(int x,y,i=1;i<n;i++)
    {
        scanf("%lld%lld",&x,&y);
        add(x,y),add(y,x);
    }
    cnt=0,dfs1(r),dfs2(r,r);
    cnt=0,build(1,n,rt=cnt++);
    for(int op,x,y,k,i=1;i<=m;i++)
    {
        scanf("%lld",&op);
        if(op==1)
        {
            scanf("%lld%lld%lld",&x,&y,&k);
            updates(x,y,k);
        }
        else if(op==2)
        {
            scanf("%lld%lld",&x,&y);
            printf("%lld\n",sum(x,y));
        }
        else if(op==3)
        {
            scanf("%lld%lld",&x,&y);
            update(id[x],id[x]+size[x]-1,y,rt);
        }
        else
        {
            scanf("%lld",&x);
            printf("%lld\n",query(id[x],id[x]+size[x]-1,rt));
        }
    }
    return 0;
}
2021/9/2 17:40
加载中...