调2h了,对着题解实在找不出来错了Orz 求调
查看原帖
调2h了,对着题解实在找不出来错了Orz 求调
137603
zhiyangfanshotacon楼主2021/8/31 19:32

哇实在调不出来了,麻烦大家了

#include <cstdio>
#include <stack>
#define Fa(x) (node[x].fa)
#define Ls(x) (node[x].ch[0])
#define Rs(x) (node[x].ch[1])
inline void read(int& x)
{
    x = 0; char ch; int f = 1;
    while ((ch = getchar()) < '0' || ch > '9')
        f = (ch ^ '-' ? 1 : -1);
    while (x = (x << 1) + (x << 3) + ch - '0',
    (ch = getchar()) >= '0' && ch <= '9') ;
    x *= f;
}
const int mod = 51061, N = 1e5 + 10;
struct Splay{ int fa, ch[2], sum, add, mul, val, size, revFlag; }node[N];
inline void swap(int& a, int& b)
{ int t = a; a = b; b = t; }
inline void pushup(int x)
{ 
    node[x].sum = (node[Ls(x)].sum + node[Rs(x)].sum + node[x].val) % mod; 
    node[x].size = node[Ls(x)].size + node[Rs(x)].size + 1;
}
inline void mul(int x, int v)
{ 
    node[x].val = 1ll * node[x].val * v % mod;  
    node[x].sum = 1ll * node[x].sum * v % mod;
    node[x].mul = 1ll * node[x].mul * v % mod;
    node[x].add = 1ll * node[x].add * v % mod;
}
inline void add(int x, int v)
{ 
    node[x].val = (node[x].val + v) % mod; 
    node[x].sum = (node[x].sum + 1ll * node[x].size * v % mod) % mod;
    node[x].add = (node[x].add + v) % mod;
}
inline void reverse(int x)
{ swap(Ls(x), Rs(x)); node[x].revFlag ^= 1; }
inline void pushdown(int x)
{
    if (node[x].mul != 1)
    {
        if (Ls(x)) mul(Ls(x), node[x].mul);
        if (Rs(x)) mul(Rs(x), node[x].mul);
        node[x].mul = 1;
    }
    if (node[x].add)
    {
        if (Ls(x)) add(Ls(x), node[x].add);
        if (Rs(x)) add(Rs(x), node[x].add);
        node[x].add = 0;
    }
    if (node[x].revFlag)
    {
        if (Ls(x)) reverse(Ls(x));
        if (Rs(x)) reverse(Rs(x));
        node[x].revFlag = 0;
    }
}
inline int get(int x) { return Rs(Fa(x)) == x ? 1 : (Ls(Fa(x)) == x ? 0 : -1); }
inline void rotate(int x)
{
    int fa = Fa(x), gf = Fa(fa), d = get(x), dd = get(fa);
    if (dd != -1) node[gf].ch[dd] = x;
    node[fa].ch[d] = node[x].ch[d ^ 1]; Fa(node[x].ch[d ^ 1]) = fa;
    node[x].ch[d ^ 1] = fa; Fa(fa) = x; Fa(x) = gf;
    pushup(fa); pushup(x);
} 
inline void splay(int x)
{
    std::stack<int> st; int t = x;
    while (get(t) != -1) st.push(t), t = Fa(t);
    while (!st.empty()) pushdown(st.top()), st.pop();
    while (get(x) != -1)
    {
        if (get(Fa(x)) != -1)
            rotate(get(Fa(x)) == get(x) ? Fa(x) : x);
        rotate(x);
    }
    pushup(x);
}
inline void access(int x)
{
    int y = 0;
    while (x)
    {
        splay(x); Rs(x) = y; pushup(x);
        y = x; x = Fa(x);
    }
}
inline void makeroot(int x) { access(x); splay(x); reverse(x); }
inline void split(int x, int y) { makeroot(y); access(x); splay(x); }
inline void link(int x, int y) { makeroot(y); Fa(y) = x; }
inline void cut(int x, int y) { split(x, y); Fa(y) = Ls(x) = 0; }
int main()
{
    int n, q; char op[5]; read(n); read(q);
    for (int i = 1; i <= n; ++i)
        node[i].size = node[i].mul = node[i].val = 1;
    for (int i = 1, x, y; i < n; ++i)
        read(x), read(y), link(x, y);   
    for (int i = 1, a, b, c, d; i <= q; ++i)
    {
        scanf("%s", op);
        switch (op[0])
        {
            case '+':
                read(a); read(b); read(c);
                split(a, b); add(a, c);
                break;
            case '-':
                read(a); read(b); read(c); read(d);
                cut(a, b); link(c, d);
                break;
            case '*':
                read(a); read(b); read(c);
                split(a, b); mul(a, c);
                break;
            case '/':
                read(a); read(b);
                split(a, b); printf("%d\n", node[a].sum);
                break;
        }
    }
    return 0;
}
2021/8/31 19:32
加载中...