// Author's note: wow, I really can't get this debugged -- sorry to trouble everyone.
#include <cstdio>
#include <stack>
// Shorthand accessors into the global node[] pool for node index x:
//   Fa(x) - parent link, Ls(x) - child slot 0 (left), Rs(x) - child slot 1 (right).
// Each expands to an lvalue, so the macros can be assigned to as well as read.
#define Fa(x) (node[x].fa)
#define Ls(x) (node[x].ch[0])
#define Rs(x) (node[x].ch[1])
// Reads the next decimal integer from stdin into x.
//
// Every non-digit character before the number is skipped; the number is
// negated only when the character immediately preceding its first digit is
// '-'. One character after the last digit is consumed as well.
//
// If end of input is reached before any digit is seen, x is set to 0 and the
// function returns instead of looping forever.
inline void read(int& x)
{
    x = 0;
    int sign = 1;
    // getchar() returns int so that EOF stays distinguishable from real bytes.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        sign = (ch == '-') ? -1 : 1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    x *= sign;
}
// mod: modulus applied to every val/sum/tag computation below.
// N:   capacity of the node pool (node indices start at 1; index 0 is used as "no node").
const int mod = 51061, N = 1e5 + 10;
// One splay-tree node. Fields:
//   fa       - parent index (0 when there is none)
//   ch[2]    - child indices: ch[0] left, ch[1] right (0 when absent)
//   sum      - sum of val over this node's splay subtree, modulo mod (see pushup)
//   add, mul - pending lazy tags still to be applied to the children (see pushdown)
//   val      - this node's own value
//   size     - number of nodes in this node's splay subtree (see pushup)
//   revFlag  - non-zero when the children's subtrees still have to be reversed
struct Splay{ int fa, ch[2], sum, add, mul, val, size, revFlag; }node[N];
// Exchanges the values of the two referenced ints.
inline void swap(int& a, int& b)
{
    const int saved = b;
    b = a;
    a = saved;
}
// Recomputes the aggregates of node x (subtree size and subtree sum modulo
// mod) from its two children and its own value.
inline void pushup(int x)
{
    const int left = Ls(x);
    const int right = Rs(x);
    node[x].size = node[left].size + node[right].size + 1;
    node[x].sum = (node[left].sum + node[right].sum + node[x].val) % mod;
}
// Applies "multiply everything by v" to the subtree rooted at x: the node's
// value and aggregate sum are scaled immediately, and both pending tags are
// scaled so that the children receive the same effect on a later pushdown.
// All products are formed in 64-bit before reduction modulo mod.
inline void mul(int x, int v)
{
    const long long factor = v;
    node[x].sum = factor * node[x].sum % mod;
    node[x].val = factor * node[x].val % mod;
    node[x].add = factor * node[x].add % mod;
    node[x].mul = factor * node[x].mul % mod;
}
// Applies "add v to every node" to the subtree rooted at x: the node's value
// grows by v, the aggregate sum grows by size * v, and v is accumulated into
// the pending add tag for the children. Results are reduced modulo mod.
inline void add(int x, int v)
{
    // Contribution of the whole subtree, computed in 64-bit to avoid overflow.
    const int subtreeGain = 1ll * node[x].size * v % mod;
    node[x].sum = (node[x].sum + subtreeGain) % mod;
    node[x].val = (node[x].val + v) % mod;
    node[x].add = (node[x].add + v) % mod;
}
// Reverses the subtree rooted at x: x's own children are swapped right away,
// and the flag is toggled so the grandchildren get reversed on pushdown.
inline void reverse(int x)
{
    swap(node[x].ch[0], node[x].ch[1]);
    node[x].revFlag = node[x].revFlag ^ 1;
}
// Pushes x's pending lazy tags onto its existing children and clears them.
// The multiply tag is applied before the add tag, matching how mul() rescales
// a pending add; the reversal flag is handled last.
inline void pushdown(int x)
{
    const int left = Ls(x);
    const int right = Rs(x);

    const int pendingMul = node[x].mul;
    if (pendingMul != 1)
    {
        if (left) mul(left, pendingMul);
        if (right) mul(right, pendingMul);
        node[x].mul = 1;
    }

    const int pendingAdd = node[x].add;
    if (pendingAdd)
    {
        if (left) add(left, pendingAdd);
        if (right) add(right, pendingAdd);
        node[x].add = 0;
    }

    if (node[x].revFlag)
    {
        if (left) reverse(left);
        if (right) reverse(right);
        node[x].revFlag = 0;
    }
}
// Tells which child slot of its parent x occupies: 1 for right, 0 for left,
// and -1 when the parent lists x in neither slot (x is the root of its splay
// tree, possibly hanging off the parent through a parent-only link).
inline int get(int x)
{
    const int parent = Fa(x);
    if (Rs(parent) == x) return 1;
    if (Ls(parent) == x) return 0;
    return -1;
}
inline void rotate(int x)
{
int fa = Fa(x), gf = Fa(fa), d = get(x), dd = get(fa);
if (dd != -1) node[gf].ch[dd] = x;
node[fa].ch[d] = node[x].ch[d ^ 1]; Fa(node[x].ch[d ^ 1]) = fa;
node[x].ch[d ^ 1] = fa; Fa(fa) = x; Fa(x) = gf;
pushup(fa); pushup(x);
}
// Rotates x up until it is the root of its splay tree.
//
// Before any rotation, the pending lazy tags of every node on the path from
// the splay root down to x are pushed down, top first. The root itself must be
// part of that path: leaving it out keeps its mul/add/reverse tags attached
// while rotations change which nodes sit below it, and leaves x untouched when
// x already is the root.
inline void splay(int x)
{
    std::stack<int> st;
    int t = x;
    while (get(t) != -1) st.push(t), t = Fa(t);
    st.push(t);  // t is now the splay root; it needs its tags pushed down too
    while (!st.empty()) pushdown(st.top()), st.pop();

    while (get(x) != -1)
    {
        // Double rotation: rotate the parent first when x and its parent are
        // on the same side (zig-zig), otherwise rotate x twice (zig-zag).
        if (get(Fa(x)) != -1)
            rotate(get(Fa(x)) == get(x) ? Fa(x) : x);
        rotate(x);
    }
    pushup(x);
}
// Walks upward from x following parent links: each visited node is splayed to
// the root of its splay tree, its right child is replaced by the previously
// visited node (0 for the first one), and its aggregates are refreshed.
// The walk ends when the parent link is 0.
inline void access(int x)
{
    for (int prev = 0; x != 0; prev = x, x = Fa(x))
    {
        splay(x);
        Rs(x) = prev;
        pushup(x);
    }
}
// Makes x the root of its tree: access x, bring it to the top of its splay
// tree, then reverse that splay tree.
inline void makeroot(int x)
{
    access(x);
    splay(x);
    reverse(x);
}
// Isolates the path between x and y: y becomes the tree root, x is accessed
// and splayed, so afterwards node[x] carries the aggregates of the whole path.
inline void split(int x, int y)
{
    makeroot(y);
    access(x);
    splay(x);
}
// Adds the edge x - y by making y the root of its tree and pointing its
// parent link at x. No check is made that x and y were in different trees.
inline void link(int x, int y)
{
    makeroot(y);
    Fa(y) = x;
}
inline void cut(int x, int y) { split(x, y); Fa(y) = Ls(x) = 0; }
// Program entry point.
//
// Input: n and q, then n - 1 edges, then q operations selected by the first
// character of the operation token:
//   + a b c    add c to every node on the path a..b
//   - a b c d  remove edge a-b, then add edge c-d
//   * a b c    multiply every node on the path a..b by c
//   / a b      print the sum of the path a..b (modulo mod)
// Every node starts with value 1.
int main()
{
    int n, q;
    char op[5];
    read(n); read(q);

    // A fresh node is a one-node splay tree: size 1, value 1, so sum 1, and a
    // neutral multiply tag.
    for (int i = 1; i <= n; ++i)
        node[i].size = node[i].mul = node[i].val = node[i].sum = 1;

    for (int i = 1, x, y; i < n; ++i)
        read(x), read(y), link(x, y);

    for (int i = 1; i <= q; ++i)
    {
        int a = 0, b = 0, c = 0, d = 0;
        // Width-limited so a long token cannot overflow op[5]; stop cleanly
        // when the input ends early.
        if (scanf("%4s", op) != 1) break;
        switch (op[0])
        {
        case '+':
            read(a); read(b); read(c);
            split(a, b); add(a, c);
            break;
        case '-':
            read(a); read(b); read(c); read(d);
            cut(a, b); link(c, d);
            break;
        case '*':
            read(a); read(b); read(c);
            split(a, b); mul(a, c);
            break;
        case '/':
            read(a); read(b);
            split(a, b); printf("%d\n", node[a].sum);
            break;
        }
    }
    return 0;
}