# include <bits/stdc++.h>
# include <iostream>
using namespace std;
// One cell of a singly linked adjacency list.
struct Node {
    int p;       // neighbouring vertex id
    Node* nxt;   // next cell of the same list, or NULL at the end
};
// Adjacency lists: p[v] is a dummy head whose nxt is v's first edge cell.
struct Node p[100005];
// c[v]  : colour of vertex v (filled in main).
// cc[i] : colour of the vertex that dfs2 placed at DFS position i.
int c[100005], cc[100005];
// Heavy-light decomposition tables, indexed by vertex id:
//   fa = parent, dep = depth, tot = subtree size, son = heavy child (0 if none),
//   idx = DFS position, top = head of the vertex's heavy chain.
int fa[100005], dep[100005], tot[100005], son[100005], idx[100005], top[100005];
// cnt: last DFS position handed out by dfs2; n: number of vertices.
int cnt, n;
// Summary of a contiguous range of DFS positions kept in each segment-tree node.
struct Tree {
    int res;  // number of maximal same-colour segments in the range
    int l;    // colour of the leftmost position
    int r;    // colour of the rightmost position
};
// Segment tree over DFS positions 1..n; node k has children 2k and 2k+1.
struct Tree Tree[4 * 100005];
// Pending colour assignment per segment-tree node, handed to its children by push_down.
int tag[4 * 100005];
// Allocates one adjacency-list cell with malloc. The fields are left
// uninitialised; the caller (add_edge) sets both p and nxt.
struct Node* ini()
{
    return static_cast<struct Node*>(malloc(sizeof(struct Node)));
}
// Adds the undirected edge {x, y}: a cell pointing at y is pushed onto the
// front of x's list and a cell pointing at x onto the front of y's list.
// Both old list heads are read before either list is relinked.
void add_edge(int x,int y)
{
    struct Node* forward = ini();
    struct Node* backward = ini();
    struct Node* const head_x = p[x].nxt;
    struct Node* const head_y = p[y].nxt;

    forward->p = y;
    forward->nxt = head_x;
    backward->p = x;
    backward->nxt = head_y;

    p[x].nxt = forward;
    p[y].nxt = backward;
}
// First heavy-light pass over the subtree of x, whose parent is f (0 for the root).
// Fills fa[x], dep[x] = dep[f] + 1, tot[x] (subtree size) and son[x] (the child with
// the strictly largest subtree, first one wins on ties; 0 for a leaf).
// Recursive: the recursion depth equals the height of the tree.
void dfs1(int x,int f)
{
    fa[x] = f;
    dep[x] = dep[f] + 1;
    tot[x] = 1;
    son[x] = 0;
    int heaviest = 0;
    for (struct Node* e = p[x].nxt; e != NULL; e = e->nxt)
    {
        const int v = e->p;
        if (v == f)
            continue;
        dfs1(v, x);
        tot[x] += tot[v];
        if (tot[v] > heaviest)
        {
            heaviest = tot[v];
            son[x] = v;
        }
    }
}
// Second heavy-light pass. f is the head of the chain x belongs to.
// Hands out consecutive DFS positions (idx), visiting the heavy child first so
// every chain occupies a contiguous position range, and copies colours into cc
// in position order. Every light child starts a new chain headed by itself.
void dfs2(int x,int f)
{
    top[x] = f;
    idx[x] = ++cnt;
    cc[cnt] = c[x];
    if (!son[x])
        return;
    dfs2(son[x], f);
    for (struct Node* e = p[x].nxt; e != NULL; e = e->nxt)
    {
        const int v = e->p;
        if (v != fa[x] && v != son[x])
            dfs2(v, v);
    }
}
void push_up(int node)
{
Tree[node].res = Tree[node*2].res + Tree[node*2+1].res;
if (Tree[node*2].r == Tree[node*2+1].l)
{
Tree[node].res--;
}
Tree[node].l = Tree[node*2].l;
Tree[node].r = Tree[node*2+1].r;
return ;
}
// Builds the segment-tree node covering DFS positions [l, r] from cc[].
// A leaf is a single segment whose both ends have the colour cc[l].
void build(int node,int l,int r)
{
    if (l != r)
    {
        const int mid = (l + r) / 2;
        build(node * 2, l, mid);
        build(node * 2 + 1, mid + 1, r);
        push_up(node);
        return;
    }
    Tree[node].res = 1;
    Tree[node].l = cc[l];
    Tree[node].r = cc[l];
}
// has_tag[node] is true while `node` carries a pending assignment stored in tag[node].
// An explicit flag is needed because the old test `if (tag[node])` treated colour 0
// as "no pending tag". Nothing in this file rejects colour 0 (colours are read with
// %d), so assigning 0 to a range updated that node but never reached its children.
// NOTE(review): SDOI2011 allows colours in [0, 1e9] — confirm against the judge data.
static bool has_tag[100005*4];

// Paints the whole range of `node` with colour w: the range becomes one segment
// with colour w at both ends, and the assignment is remembered for the children.
void add_tag(int node,int w)
{
    Tree[node].res = 1;
    Tree[node].l = Tree[node].r = w;
    tag[node] = w;
    has_tag[node] = true;
}

// Hands a pending assignment of `node` down to both children, then clears it.
// Only called on internal nodes (upd/query never recurse past a fully covered leaf).
void push_down(int node)
{
    if (!has_tag[node])
        return;
    add_tag(node * 2 + 1, tag[node]);
    add_tag(node * 2, tag[node]);
    has_tag[node] = false;
}
// Assigns colour w to DFS positions [tl, tr] inside the node covering [l, r].
// A fully covered node is tagged lazily; otherwise the pending tag is pushed
// down, the overlapping children are updated, and the summary is rebuilt.
void upd(int node,int l,int r,int tl,int tr,int w)
{
    if (tl > l || r > tr)
    {
        const int mid = (l + r) / 2;
        push_down(node);
        if (tl <= mid)
            upd(node * 2, l, mid, tl, tr, w);
        if (tr > mid)
            upd(node * 2 + 1, mid + 1, r, tl, tr, w);
        push_up(node);
        return;
    }
    add_tag(node, w);
}
struct Tree query(int node,int l,int r,int tl,int tr)
{
if (tl <= l && r <= tr)
{
return Tree[node];
}
int mid = (l+r)/2;
struct Tree t1,t2,t3;
int c1=0,c2=0;
push_down(node);
if (mid >= tl)
{
t1 = query(node*2,l,mid,tl,tr);
c1=1;
}
if (mid < tr)
{
t2 = query(node*2+1,mid+1,r,tl,tr);
c2=1;
}
if (c1 == c2)
{
t3.res = t1.res + t2.res;
if (t1.r == t2.l)
{
t3.res--;
}
t3.l = t1.l;
t3.r = t2.r;
}
else if (c1)
{
t3 = t1;
}
else if (c2)
{
t3 = t2;
}
return t3;
}
// Colours every vertex on the tree path x..y with w.
// Jumps the endpoint with the deeper chain head up one chain at a time, updating
// that chain's position range. Once both endpoints share a chain, it updates the
// range between them (on one chain, the shallower vertex has the smaller idx).
void opt2(int x,int y,int w)
{
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        if (dep[top[x]] < dep[top[y]])
            std::swap(x, y);
        upd(1, 1, n, idx[top[x]], idx[x], w);
    }
    upd(1, 1, n, std::min(idx[x], idx[y]), std::max(idx[x], idx[y]), w);
}
// Counts the maximal same-colour segments on the tree path x..y.
// Each chain piece contributes its own segment count. Where a chain head meets
// its parent on the path and both have the same colour, the two pieces'
// boundary segments are really one, so one is subtracted.
int opt1(int x,int y)
{
    int segments = 0;
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        if (dep[top[x]] < dep[top[y]])
            std::swap(x, y);
        const int head = top[x];
        const int parent = fa[head];
        segments += query(1, 1, n, idx[head], idx[x]).res;
        const int parent_colour = query(1, 1, n, idx[parent], idx[parent]).l;
        const int head_colour = query(1, 1, n, idx[head], idx[head]).l;
        if (parent_colour == head_colour)
            --segments;
    }
    segments += query(1, 1, n, std::min(idx[x], idx[y]), std::max(idx[x], idx[y])).res;
    return segments;
}
// Reads n, m, the n vertex colours and n-1 edges, builds the heavy-light
// decomposition rooted at vertex 1, then processes m operations:
//   Q x y   -> print the number of colour segments on the path x..y
//   C x y w -> colour every vertex on the path x..y with w
int main (void)
{
    //freopen("P2486.in","r",stdin);
    int m;
    if (scanf("%d %d", &n, &m) != 2)
        return 0;
    for (int i = 1; i <= n; i++)
        scanf("%d", &c[i]);
    for (int i = 0; i < n - 1; i++)
    {
        int x, y;
        scanf("%d %d", &x, &y);
        add_edge(x, y);
    }
    dfs1(1, 0);
    dfs2(1, 1);
    build(1, 1, n);
    for (int i = 0; i < m; i++)
    {
        // The leading space in " %c" skips ALL whitespace (blanks, '\n', and the
        // '\r' of CRLF files) before the operation letter. The previous
        // getchar() + "%c " scheme assumed exactly one '\n' per line. On CRLF
        // data or lines with trailing spaces, `opt` became '\r'/'\n'/' ' and the
        // trailing getchar() swallowed the real 'Q'/'C', so no query matched and
        // nothing was printed ("Too short on line 1").
        char opt;
        if (scanf(" %c", &opt) != 1)
            break;
        if (opt == 'Q')
        {
            int x, y;
            scanf("%d %d", &x, &y);
            printf("%d\n", opt1(x, y));
        }
        else if (opt == 'C')
        {
            // Named `w` so it no longer shadows the global cc[] array.
            int x, y, w;
            scanf("%d %d %d", &x, &y, &w);
            opt2(x, y, w);
        }
    }
    return 0;
}
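// 样例能过，但所有测试点都报 "Too short on line 1"。是输入的问题吗？
// (The samples pass, but every test case reports "Too short on line 1". Is this an input-reading problem?)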