蒟蒻求助树剖
查看原帖
蒟蒻求助树剖
159959
虫洞吞噬者楼主2021/11/7 22:11

RT,蒟蒻刚刚开始学习树剖,按照OI Wiki上的思路敲出了这份代码,并且大致地和题解对照了一下,并没有发现什么重大的逻辑错误,但一直在WA(一片红)

code:

#include<string>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
int n,m,ans,cnt,tot;
int num[100100],head[100100],son[100100],fa[100100],tp[100100],dep[100100],siz[100100],rk[100100],dfn[100100];
struct Edge{
	int nxt,to;
}edge[200200];
struct Tree{
	int l,r,sum,lz,maxn;
}tree[400400];
void add(int from,int to)
{
	edge[++cnt].nxt=head[from];
	edge[cnt].to=to;
	head[from]=cnt;
}
void dfs1(int s,int pre)
{
	fa[s]=pre;
	dep[s]=dep[pre]+1;
	siz[s]=1;
	for(int i=head[s];i;i=edge[i].nxt)
	{
		int nxt=edge[i].to;
		if(nxt==pre)continue;
		dfs1(nxt,s);
		siz[s]+=siz[nxt];
		if(siz[nxt]>siz[son[s]])son[s]=nxt;
	}
}
void dfs2(int s,int top)
{
	tp[s]=top;
	dfn[s]=++tot;
	rk[tot]=s;
	if(son[s])dfs2(son[s],top);
	for(int i=head[s];i;i=edge[i].nxt)
	{
		int nxt=edge[i].to;
		if(nxt==fa[s]||nxt==son[s])continue;
		dfs2(nxt,nxt);
	}
}
inline void pushup(int id)
{
	tree[id].sum=tree[id*2].sum+tree[id*2+1].sum;
	tree[id].maxn=max(tree[id*2].maxn,tree[id*2+1].maxn);
}
void build(int id,int l,int r)
{
	tree[id].l=l;tree[id].r=r;
	if(l==r)
	{
		tree[id].maxn=tree[id].sum=num[rk[l]];
		return;
	}
	int mid=(l+r)/2;
	build(id*2,l,mid);
	build(id*2+1,mid+1,r);
	pushup(id);
}
void change(int id,int p,int k)
{
	if(tree[id].l==tree[id].r&&tree[id].l==p)
	{
		tree[id].maxn=tree[id].sum=k;
		return;
	}
	int mid=(tree[id].l+tree[id].r)/2;
	if(p<=mid)change(id*2,p,k);
	if(p>mid)change(id*2+1,p,k);
	pushup(id);
}
int finds(int id,int l,int r)
{
	if(l<=tree[id].l&&tree[id].r<=r)return tree[id].sum;
	int sum=0,mid=(tree[id].l+tree[id].r)/2;
	if(l<=mid)sum+=finds(id*2,l,r);
	if(r>mid)sum+=finds(id*2+1,l,r);
	return sum;
}
int findm(int id,int l,int r)
{
	if(l<=tree[id].l&&tree[id].r<=r)return tree[id].maxn;
	int maxn=-1,mid=(tree[id].l+tree[id].r)/2;
	if(l<=mid)maxn=max(maxn,findm(id*2,l,r));
	if(r>mid)maxn=max(maxn,findm(id*2+1,l,r));
	return maxn;
}
int sumpath(int x,int y)
{
	int sum=0;
	while(tp[x]!=tp[y])
	{
		if(dep[tp[x]]<dep[tp[y]])swap(x,y);
		sum+=finds(1,dfn[tp[x]],dfn[x]);
		x=tp[x];
		x=fa[x];
	}
	if(dep[x]<dep[y])swap(x,y);
	sum+=finds(1,dfn[y],dfn[x]);
	return sum;
}
int maxpath(int x,int y)
{
	int maxn=-1000000000;
	while(tp[x]!=tp[y])
	{
		if(dep[tp[x]]<dep[tp[y]])swap(x,y);
		maxn=max(maxn,findm(1,dfn[tp[x]],dfn[x]));
		x=tp[x];
		x=fa[x];
	}
	if(dep[x]<dep[y])swap(x,y);
	maxn=max(maxn,findm(1,dfn[y],dfn[x]));
	return maxn;
}
int main()
{
	scanf("%d",&n);
	for(int i=1;i<n;++i)
	{
		int a,b;
		scanf("%d%d",&a,&b);
		add(a,b);
		add(b,a);
	}
	for(int i=1;i<=n;++i)scanf("%d",&num[i]);
	dfs1(1,0);
	fa[1]=1;
	dfs2(1,1);
	build(1,1,n);
	scanf("%d",&m);
	for(int i=1;i<=m;++i)
	{
		string s;
		int x,y;
		cin>>s;
		scanf("%d%d",&x,&y);
		if(s=="CHANGE")change(1,dfn[x],y);
		else if(s=="QMAX")printf("%d\n",maxpath(x,y));
		else printf("%d\n",sumpath(x,y));
	}
	return 0;
}
2021/11/7 22:11
加载中...