20pts求助!
查看原帖
20pts求助!
192374
wenkaijie楼主2021/1/20 23:56
# include <stdio.h>
# include <iostream>
# include <string.h>
# include <string>
# include <algorithm>
# include <math.h>
# include <queue>
# include <deque>
# include <climits>
# include <stack>
using namespace std;
// Problem sizes: n = number of tree nodes, q = number of queries.
int n, q;

// One directed half of an undirected tree edge (main() inserts both
// directions), stored in a static forward-star adjacency structure.
struct edge
{
	int to;   // node this half-edge points to
	int next; // index of the next half-edge leaving the same node; 0 ends the list
};
edge e[60001];

// head[v] = index in e[] of the first half-edge leaving v; 0 means v has none.
int head[30001];
// Insert the directed half-edge x -> y at the front of x's adjacency list.
//   x   - source node
//   y   - destination node
//   cnt - slot in e[] that receives this half-edge; each call needs its own
//         slot, and it must be >= 1 because index 0 is the end-of-list marker
//         used by the traversal loops (`for(i=head[x];i;...)`).
void add_edge(int x,int y,int cnt)
{
	// Assign the fields directly. The previous `(edge){y,head[x]}` was a GNU
	// compound literal, which is an extension and not standard C++.
	e[cnt].to=y;
	e[cnt].next=head[x];
	head[x]=cnt;
}
// a[v]: weight of node v as read from input. Later CHANGE operations update
// only the segment tree, not this array.
int a[30001];
// Heavy-light decomposition bookkeeping, indexed by node unless noted.
int dep[30001];  // depth of v (the root, whose parent is 0, has depth 1)
int fa[30001];   // parent of v (0 for the root)
int son[30001];  // heavy child of v (0 if v has no children)
int size[30001]; // number of nodes in v's subtree
int top[30001];  // first node of the heavy chain containing v
int id[30001];   // DFS position of v, used as the segment-tree index
int what[30001]; // indexed by position: what[id[v]] == v
int cnt;         // last DFS position handed out by dfs2
void dfs1(int x,int f)
{
	dep[x]=dep[f]+1;
	fa[x]=f;
	size[x]=1;
	int maxsize=0,pos=0;
	for(int i=head[x];i;i=e[i].next)
		if(e[i].to!=f)
		{
			dfs1(e[i].to,x);
			size[x]+=size[e[i].to];
			if(size[e[i].to]>maxsize)
			{
				maxsize=size[e[i].to];
				pos=e[i].to;
			}
		}
	son[x]=pos;
}
// Second heavy-light pass: give every node a DFS position so that each heavy
// chain occupies a contiguous block of positions, and record chain heads.
//   x    - current node (0 means "no node", e.g. son[] of a leaf)
//   root - first node of the heavy chain that x belongs to
// Fills top[], id[] and what[], advancing the global counter cnt.
void dfs2(int x,int root)
{
	if(!x)
		return; // came from a node without a heavy child
	top[x]=root;
	cnt++;
	id[x]=cnt;
	what[cnt]=x;
	// Heavy child first, so it continues the current chain with the next id.
	dfs2(son[x],root);
	// Every other (light) child starts a new chain headed by itself.
	for(int i=head[x];i!=0;i=e[i].next)
	{
		int child=e[i].to;
		bool isParent=(child==fa[x]);
		bool isHeavy=(child==son[x]);
		if(!isParent && !isHeavy)
			dfs2(child,child);
	}
}
// Node of a pointer-based segment tree over DFS positions.
struct node
{
	int l,r;     // inclusive range of positions covered by this node
	int sum;     // sum of the values in [l, r]
	int maxa;    // maximum of the values in [l, r]
	node *left;  // child covering the lower half; NULL for a leaf
	node *right; // child covering the upper half; NULL for a leaf
	// A fresh node covers nothing yet and has no children.
	node():l(0),r(0),sum(0),maxa(0),left(NULL),right(NULL)
	{
	}
}*root; // root of the whole tree; NULL until assigned
// Recursively build the segment tree over positions [l, r].
// The leaf at position p holds a[what[p]], the weight of the node whose DFS
// position is p. Nodes are allocated with `new`; nothing in this file frees them.
node *build(int l,int r)
{
	node *cur=new node;
	cur->l=l;
	cur->r=r;
	if(l!=r)
	{
		int half=(l+r)/2;
		cur->left=build(l,half);
		cur->right=build(half+1,r);
		cur->sum=cur->left->sum+cur->right->sum;
		cur->maxa=max(cur->left->maxa,cur->right->maxa);
	}
	else
	{
		int v=a[what[l]];
		cur->sum=v;
		cur->maxa=v;
	}
	return cur;
}
// Point assignment: set the value at DFS position x to y, then recompute sum
// and maxa along the path back up to `now`.
// A position outside [now->l, now->r] leaves the tree unchanged.
void change(node *now,int x,int y)
{
	if(x<now->l || x>now->r)
		return;
	if(now->l==now->r)
	{
		now->sum=y;
		now->maxa=y;
		return;
	}
	// Only the child whose range contains x can change.
	if(x<=now->left->r)
		change(now->left,x,y);
	else
		change(now->right,x,y);
	now->sum=now->left->sum+now->right->sum;
	now->maxa=max(now->left->maxa,now->right->maxa);
}
// Sum of the values at positions in [l, r] within now's subtree.
// Returns 0 when [l, r] does not overlap now's range.
int getsum(node *now,int l,int r)
{
	bool disjoint=(now->r<l) || (now->l>r);
	if(disjoint)
		return 0;
	bool covered=(now->l>=l) && (now->r<=r);
	if(covered)
		return now->sum;
	int fromLeft=getsum(now->left,l,r);
	int fromRight=getsum(now->right,l,r);
	return fromLeft+fromRight;
}
// Maximum of the values at positions in [l, r] within now's subtree.
// Returns the sentinel -2147483647 when [l, r] does not overlap now's range.
int getmax(node *now,int l,int r)
{
	const int kNoOverlap=-2147483647;
	if(now->r<l || now->l>r)
		return kNoOverlap;
	if(now->l>=l && now->r<=r)
		return now->maxa;
	int leftBest=getmax(now->left,l,r);
	int rightBest=getmax(now->right,l,r);
	return leftBest>rightBest ? leftBest : rightBest;
}
// Sum of node weights on the tree path between x and y, inclusive.
// While the endpoints are on different heavy chains, take the whole chain
// segment of the endpoint whose chain head is deeper and jump above that head.
// Once both are on one chain, finish with a single contiguous range.
int Getsum(int x,int y)
{
	int total=0;
	while(top[x]!=top[y])
	{
		if(dep[top[x]]<dep[top[y]])
			swap(x,y);
		total+=getsum(root,id[top[x]],id[x]);
		x=fa[top[x]];
	}
	// Same chain: the shallower endpoint has the smaller DFS position.
	if(dep[x]<dep[y])
		swap(x,y);
	return total+getsum(root,id[y],id[x]);
}
// Maximum node weight on the tree path between x and y, inclusive.
// Uses the same chain-jumping walk as the path sum, combining the segment
// maxima instead of adding them.
int Getmax(int x,int y)
{
	int best=INT_MIN; // identity for max: max(INT_MIN, v) == v for every int v
	while(top[x]!=top[y])
	{
		if(dep[top[x]]<dep[top[y]])
			swap(x,y);
		best=max(best,getmax(root,id[top[x]],id[x]));
		x=fa[top[x]];
	}
	// Same chain: the shallower endpoint has the smaller DFS position.
	if(dep[x]<dep[y])
		swap(x,y);
	return max(best,getmax(root,id[y],id[x]));
}
int main()
{
	scanf("%d",&n);
	for(int i=1;i<n;i++)
	{
		int x,y;
		scanf("%d%d",&x,&y);
		add_edge(x,y,2*i-1);
		add_edge(y,x,2*i);
	}
	for(int i=1;i<=n;i++)
		scanf("%d",&a[i]);
	dfs1(1,0);
	dfs2(1,1);
	root=build(1,n);
	scanf("%d",&q);
	for(int i=1;i<=q;i++)
	{
		char p;
		int x,y;
		scanf(" %c",&p);
		if(p=='C')
		{
			scanf("%*c%*c%*c%*c%*c");
			scanf("%d%d",&x,&y);
			change(root,x,y);
		}
		else
		{
			scanf("%c%*c%*c",&p);
			scanf("%d%d",&x,&y);
			if(p=='M')
				printf("%d\n",Getmax(x,y));
			else
				printf("%d\n",Getsum(x,y));
		}
	}
	return 0;
}

2021/1/20 23:56
加载中...