RE,求条
查看原帖
RE,求条
984551
SingKwenCat楼主2025/1/17 14:51

思路跟题解的一致,pushup操作调用了merge函数 全RE,不知错误何在,求各位大佬指点

#include <bits/stdc++.h>
#define int long long
#define endl "\n"
using namespace std;

constexpr const int MOD = 19940417;
constexpr const int N = 5e4 + 5;
struct Node {
    int siz;
    int c[25];
    int add, rev;
} sgt[N << 2];
#define lson p<<1
#define rson p<<1|1
string opt;
int n, m, x, y, z;
int a[N], c[N][25];

void combination() {
    sgt[0].c[0] = 1; // why? i dont need to use sgt[0]
    c[0][0] = 1;
    for (int i = 1; i <= n; ++i) {
        c[i][0] = 1;
        for (int j = 1; j <= min(i, 20ll); ++j) 
            c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % MOD;
    }
}

Node merge(const Node& a, const Node& b) {
    Node res = {0, {0}, 0, 0};
    res.siz = a.siz + b.siz;
    // memset(res.c, 0, sizeof res.c);
    for (int i = 0; i <= min(a.siz, 20ll); ++i) 
        for (int j = 0; j <= min(b.siz, 20ll); ++j) 
            res.c[i + j] = (res.c[i + j] + a.c[i] * b.c[j] % MOD) % MOD;
        // 这里是卷积,呃呃卡一下傅里叶优化了
    return res;
}

void add(int k, int p) {
    int tmp[25]; tmp[0] = 1;
    for (int i = 1; i <= 20; ++i) tmp[i] = k * tmp[i - 1] % MOD;
    for (int i = min(20ll, sgt[p].siz); i > 0; --i) {
        for (int j = 0; j < i; ++j) {
            sgt[p].c[i] = (sgt[p].c[i] + tmp[i - j] % MOD * sgt[p].c[j] % MOD * c[sgt[p].siz - j][i - j] % MOD) % MOD;
        }
    }
    sgt[p].add = (sgt[p].add + k) % MOD;
}

void rev(int p) {
    for (int i = 1; i <= min(20ll, sgt[p].siz); ++i)
        if (i & 1) sgt[p].c[i] = MOD - sgt[p].c[i];
    sgt[p].rev ^= 1;
    sgt[p].add = MOD - sgt[p].add;
}

void pushdown(int p) {
    if (sgt[p].rev) {
        rev(lson); rev(rson); sgt[p].rev = 0;
    }
    if (sgt[p].add) {
        add(sgt[p].add, lson); add(sgt[p].add, rson);
        sgt[p].add = 0; 
    }
}

void pushup(int p) {
    int sizz = sgt[p].siz, addd = sgt[p].add, revv = sgt[p].rev;
    sgt[p] = merge(sgt[lson], sgt[rson]);
    sgt[p].siz = sizz, sgt[p].add = addd, sgt[p].rev = revv;
}

void build(int l, int r, int p) {
    sgt[p].siz = r - l + 1;
    if (l == r) {
        sgt[p].c[0] = 1;
        sgt[p].c[1] = (a[l] % MOD + MOD) % MOD;
        return ; 
    }
    int mid = (l + r) >> 1;
    build(l, mid, lson); build(mid + 1, r, rson);
    pushup(p);
}

void moda(int x, int y, int k, int l, int r, int p) {
    if (x <= l && r <= y) return add(k, p);
    pushdown(p);
    int mid = (l + r) >> 1;
    if (x <= mid) moda(x, y, k, l, mid, lson);
    if (y  > mid) moda(x, y, k, mid + 1, r, rson);
    pushup(p);
}

void modr(int x, int y, int l, int r, int p) {
    if (x <= l && r <= y) return rev(p);
    pushdown(p);
    int mid = (l + r) >> 1;
    if (x <= mid) modr(x, y, l, mid, lson);
    if (y  > mid) modr(x, y, mid + 1, r, rson);
    pushup(p);
}

Node query(int x, int y, int l, int r, int p) {
    if (x <= l && r <= y) return sgt[p];
    pushdown(p);
    int mid = (l + r) >> 1;
    if (y <= mid) return query(x, y, l, mid, lson);
    if (x  > mid) return query(x, y, mid + 1, r, rson);
    return merge(query(x, y, l, mid, lson), query(x, y, mid + 1, r, rson));
}

void solve() {
    cin >> n >> m;
    for (int i = 1; i <= n; ++i) cin >> a[i];
    combination(); build(1, n, 1);
    for (int _ = 1; _ <= m; ++_) {
        cin >> opt >> x >> y;
        if (opt[0] == 'I') {
            cin >> z;
            moda(x, y, (MOD + z % MOD) % MOD, 1, n, 1);
        } else if (opt[0] == 'R') {
            modr(x, y, 1, n, 1);
        } else {
            cin >> z;
            cout << (query(x, y, 1, n, 1).c[z] % MOD + MOD) % MOD << endl;
        }
    }
}

signed main() {
    std::cin.tie(nullptr) -> sync_with_stdio(false);
    solve();
}
2025/1/17 14:51
加载中...