T了两个点
查看原帖
T了两个点
352352
artalter楼主2022/1/1 11:17

代码如下

#include<bits/stdc++.h>
using namespace std;
inline int ls(int x){return x<<1;}
inline int rs(int x){return x<<1|1;}
const int maxn=2e5+10;
struct edge
{
	int next,v;
}e[2*maxn];
struct point
{
	int l,r,maxx,sum;
}tree[4*maxn];
int rt,n,q,r,a[maxn],cnt,head[maxn*2],f[maxn],d[maxn],size[maxn],son[maxn],rk[maxn],top[maxn],id[maxn],num;
void add(int u,int v)
{
	num++;
	e[num].v=v;
	e[num].next=head[u];
	head[u]=num;
}
void dfs1(int u,int fa,int depth)
{
	f[u]=fa;
	d[u]=depth;
	size[u]=1;
	for(int i=head[u];i;i=e[i].next)
	{
		int v=e[i].v;
		if(v==fa)continue;
		dfs1(v,u,depth+1);
		size[u]+=size[v];
		if(size[u]>size[son[u]] || !son[u])
		{
			son[u]=v;
		}
	}
}
void dfs2(int u,int t)
{
	top[u]=t;
	id[u]=++cnt;
	rk[cnt]=u;
	if(!son[u])return;
	dfs2(son[u],t);
	for(int i=head[u];i;i=e[i].next)
	{
		int v=e[i].v;
		if(v!=son[u]&&v!=f[u])
		{
			dfs2(v,v);
		}
	}
}
void build(int p,int l,int r)
{
	int mid=(l+r)>>1;
	tree[p].l=l;
	tree[p].r=r;
	if(l==r)
	{
		tree[p].maxx=tree[p].sum=a[rk[mid]];
		return;
	}
	build(ls(p),l,mid);
	build(rs(p),mid+1,r);
	tree[p].maxx=max(tree[ls(p)].maxx,tree[rs(p)].maxx);
	tree[p].sum=tree[ls(p)].sum+tree[rs(p)].sum;
}
int query_max(int p,int l,int r)
{
	if(tree[p].l>=l&&tree[p].r<=r)
	{
		return tree[p].maxx;
	}
	int mid=(tree[p].l+tree[p].r)>>1;
	int s=-0x7fffff;
	if(l<=mid)
	{
		s=max(query_max(ls(p),l,r),s);
	}
	if(r>mid)
	{
		s=max(query_max(rs(p),l,r),s);
	}
	return s;
}


int query_sum(int p,int l,int r)
{
	if(tree[p].l>=l&&tree[p].r<=r)
	{
		return tree[p].sum;
	}
	int mid=(tree[p].l+tree[p].r)>>1;
	int s=0;
	if(l<=mid)
	{
		s+=query_sum(ls(p),l,r);
	}
	if(r>mid)
	{
		s+=query_sum(rs(p),l,r);
	}
	return s;
}
void change(int p,int l,int x)
{
	if(tree[p].l==tree[p].r)
	{
		tree[p].sum=x;
		tree[p].maxx=x;
		return;
	}
	int mid=(tree[p].l+tree[p].r)>>1;
	if(l<=mid)
	{
		change(ls(p),l,x);
	}
	if(l>mid)
	{
		change(rs(p),l,x);
	}
	tree[p].sum=tree[ls(p)].sum+tree[rs(p)].sum;
	tree[p].maxx=max(tree[ls(p)].maxx,tree[rs(p)].maxx);
}
int find_sum(int x,int y)
{
	int ret=0,fx=top[x],fy=top[y];
	while(fx!=fy)
	{
		if(d[fx]>=d[fy])
		{
			ret+=query_sum(1,id[fx],id[x]);
			x=f[fx];
			fx=top[x];
		}else
		{
			ret+=query_sum(1,id[fy],id[y]);
			y=f[fy];
			fy=top[y];
		}
	}
	if(id[x]<=id[y])ret+=query_sum(1,id[x],id[y]);
	else ret+=query_sum(1,id[y],id[x]);
	return ret;
}
int find_max(int x,int y)
{
	int ret=-21474836,fx=top[x],fy=top[y];
	while(fx!=fy)
	{
		if(d[fx]>=d[fy])
		{
			ret=max(query_max(1,id[fx],id[x]),ret);
			x=f[fx];
			fx=top[x];
		}else
		{
			ret=max(query_max(1,id[fy],id[y]),ret);
			y=f[fy];
			fy=top[y];
		}
	}
	if(id[x]<=id[y])ret=max(query_max(1,id[x],id[y]),ret);
	else ret=max(query_max(1,id[y],id[x]),ret);
	return ret;
}
int main()
{
	cin>>n;
	for(int i=1;i<=n-1;i++)
	{
		int u,v;
		scanf("%d%d",&u,&v);
		add(u,v);
		add(v,u);
	}
	for(int i=1;i<=n;i++)
	{
		cin>>a[i];
	}
	dfs1(1,0,1);
	dfs2(1,1);
	build(1,1,n);
	
	cin>>q;
	for(int i=1;i<=q;i++)
	{
		string s;
		cin>>s;
		if(s=="QSUM")
		{
			int x,y;
			scanf("%d%d",&x,&y);
			printf("%d\n",find_sum(x,y));
		}else if(s=="QMAX")
		{
			int x,y;
			scanf("%d%d",&x,&y);
			printf("%d\n",find_max(x,y));
		}else
		{
			int x,t;
			scanf("%d%d",&x,&t);
			change(1,id[x],t);
		}
	}
	return 0;
}
2022/1/1 11:17
加载中...