求助大佬，代码 WA4，讨论区的方法都试了也没用
查看原帖
求助大佬，代码 WA4，讨论区的方法都试了也没用
815093
AAAAAZBX（楼主） 2024/10/4 19:55
#include <iostream>
#include <cstring>
#include <algorithm>
#include <queue>
#include <vector>
#include <set>
#include <stack>
#include <map>
#include <cmath>
#include <chrono>
#include <random>
#include <iomanip>
using namespace std;

// Array capacity (max vertices plus slack) and assorted contest constants.
// P, M, mod and INF are not referenced by the visible solution.
constexpr int N = 200000 + 10;
constexpr int P = 131;
constexpr int M = 1000000 + 10;
constexpr int mod = 998244353;
constexpr int INF = 0x3f3f3f3f;

using PII = pair<int, int>;
using ll = long long;
using ull = unsigned long long;

// Vertex count, query count and the current root of the tree.
int n, q, root;
// Chain-forward-star adjacency list: list head per vertex, edge target,
// next edge in the same list, and the next free edge slot.
int h[N], e[N * 2], ne[N * 2], idx;
// Euler tour: entry time, last time inside the subtree, vertex at each time,
// and the running timestamp counter.
int in[N], out[N], pos[N], timestamp;
// Depth, binary-lifting ancestors f[v][k] (2^k-th ancestor, k < 20), and
// subtree size -- all relative to the DFS rooting done in main.
int dep[N], f[N][20], sz[N];
// Initial value of each vertex.
ll a[N];

// Segment-tree node covering Euler positions [l, r]; `tag` is a pending
// per-element addition not yet pushed to the children.
struct Node {
    int l, r;
    ll sum, tag;
};
Node tr[N << 2];
void add(int a, int b) {
    e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
// Preorder DFS from u whose parent is fa. Records the entry time in[u] (and
// its inverse pos[]), depth, direct parent f[u][0], subtree size sz[u], and
// out[u] = the last timestamp assigned inside u's subtree.
void dfs(int u, int fa) {
    f[u][0] = fa;
    dep[u] = dep[fa] + 1;
    in[u] = ++timestamp;
    pos[in[u]] = u;
    sz[u] = 1;
    for (int edge = h[u]; edge != -1; edge = ne[edge]) {
        const int child = e[edge];
        if (child != fa) {
            dfs(child, u);
            sz[u] += sz[child];
        }
    }
    out[u] = timestamp;
}
// Lowest common ancestor of a and b in the DFS rooting, via binary lifting.
// Vertex 0 acts as a sentinel above the root: dep[0] == 0 and f[0][*] == 0.
int lca(int a, int b) {
    if (dep[a] < dep[b]) swap(a, b);
    // Lift the deeper vertex up to b's depth.
    for (int k = 19; k >= 0; --k)
        if (dep[f[a][k]] >= dep[b]) a = f[a][k];
    if (a == b) return a;
    // Lift both while their ancestors differ; they end just below the LCA.
    for (int k = 19; k >= 0; --k) {
        if (f[a][k] == f[b][k]) continue;
        a = f[a][k];
        b = f[b][k];
    }
    return f[a][0];
}
// Return the ancestor of u that is d levels above it, by decomposing d into
// powers of two (highest bit first). Overshooting the root yields the
// sentinel vertex 0, assuming f[0][*] == 0.
int jump(int u, int d) {
    int bit = 19;
    while (bit >= 0) {
        if ((d >> bit) & 1) u = f[u][bit];
        --bit;
    }
    return u;
}
void pushup(int u) {
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void pushdown(int u) {
    if (tr[u].tag) {
        tr[u << 1].sum += (tr[u << 1].r - tr[u << 1].l + 1) * tr[u].tag;
        tr[u << 1].tag += tr[u].tag;
        tr[u << 1 | 1].sum += (tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1) * tr[u].tag;
        tr[u << 1 | 1].tag += tr[u].tag;
        tr[u].tag = 0;
    }
}
void build(int u, int l, int r) {
    tr[u] = { l,r,0,0 };
    if (l == r) {
        tr[u].sum = a[pos[l]];
        return;
    }
    int mid = l + r >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}
void modify(int u, int l, int r, ll k) {
    if (l > r)return;
    if (tr[u].l >= l && tr[u].r <= r) {
        tr[u].sum += (tr[u].r - tr[u].l + 1) * k;
        tr[u].tag += k;
        return;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if (l <= mid)modify(u << 1, l, r, k);
    if (r > mid)modify(u << 1 | 1, l, r, k);
    pushup(u);
}
// Sum of positions [l, r]; an empty range yields 0.
ll query(int u, int l, int r) {
    if (l > r) return 0;
    const Node &cur = tr[u];
    if (l <= cur.l && cur.r <= r) return cur.sum;
    pushdown(u);
    const int mid = (cur.l + cur.r) >> 1;
    ll total = 0;
    if (l <= mid) total += query(2 * u, l, r);
    if (mid < r) total += query(2 * u + 1, l, r);
    return total;
}
// Deepest of the three pairwise LCAs of a, b, c (ties keep the earlier one).
// With c = current root, this is the LCA of a and b when the tree is rooted at c.
int get(int a, int b, int c) {
    const int ab = lca(a, b);
    const int ac = lca(a, c);
    const int bc = lca(b, c);
    int best = ab;
    if (dep[ac] > dep[best]) best = ac;
    if (dep[bc] > dep[best]) best = bc;
    return best;
}
void treeModify(int u, ll x) {
    int l = in[u], r = in[u] + sz[u] - 1;
    if (u == root)modify(1, 1, n, x);
    else if (in[root]<l || in[root]>r)modify(1, l, r, x);
    else {
        int tmp = jump(root, dep[root] - dep[u] - 1);
        modify(1, 1, n, x);
        modify(1, in[tmp], in[tmp] + sz[tmp] - 1, -x);
    }
}
// Sum of u's subtree with respect to the current root; same three cases as
// treeModify (whole tree / original subtree / whole minus one child subtree).
ll treeQuery(int u) {
    const int lo = in[u], hi = in[u] + sz[u] - 1;
    const bool rootInside = lo <= in[root] && in[root] <= hi;
    if (u != root && !rootInside) return query(1, lo, hi);
    ll total = query(1, 1, n);
    if (u != root) {
        const int child = jump(root, dep[root] - dep[u] - 1);
        total -= query(1, in[child], in[child] + sz[child] - 1);
    }
    return total;
}
// Reads n, q, the vertex values and n-1 edges, then answers q operations:
//   1 r       -> make r the current root
//   2 u v x   -> add x to the subtree (under the current root) of the
//                current-root LCA of u and v
//   3 x       -> print the subtree sum of x under the current root
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    cin >> n >> q;
    root = 1;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
        h[i] = -1;  // empty adjacency list
    }
    for (int i = 1, u, v; i < n; i++) {
        cin >> u >> v;
        add(u, v), add(v, u);
    }
    dfs(1, 0);
    // Build the binary-lifting table one level at a time. Level j must be
    // filled for every vertex only after level j-1 is complete for every
    // vertex: f[f[i][j-1]][j-1] reads an ancestor's row, and that ancestor
    // may have a larger index than i. Iterating vertices in the outer loop
    // (as before) read not-yet-computed zeros and broke lca()/jump().
    // Entries past the root resolve to the sentinel 0 because f[0][*] == 0.
    for (int j = 1; j < 20; j++)
        for (int i = 1; i <= n; i++)
            f[i][j] = f[f[i][j - 1]][j - 1];
    build(1, 1, n);
    while (q--) {
        int op, u, v, x;
        cin >> op;
        if (op == 1) {
            cin >> root;
        }
        else if (op == 2) {
            cin >> u >> v >> x;
            treeModify(get(u, v, root), x);
        }
        else {
            cin >> x;
            cout << treeQuery(x) << '\n';
        }
    }
    return 0;
}
2024/10/4 19:55
加载中...