#include <bits/stdc++.h>
#define pb push_back
#define fi first
#define se second
#define sz size
#define all(x) x.begin(), x.end()
#define pii pair<int, int>
#define int long long
#define mem(x, y) memset(x, y, sizeof(x))
using namespace std;
const int N = 100000 + 5;  // capacity of a[]; the node array t[] is sized 4 * N
int n, m;                  // array length and number of operations
int a[N];                  // input values, 1-indexed a[1..n]; build() treats each as a bit
// One segment-tree node summarising the 0/1 values a[l..r].
struct SegTree {
    int l, r;        // covered index range [l, r]
    int l0, l1;      // length of the run of 0s / 1s that starts at l
    int r0, r1;      // length of the run of 0s / 1s that ends at r
    int cnt1;        // number of 1s in [l, r]
    int mx0, mx1;    // longest run of 0s / 1s anywhere in [l, r]
    int tag0, tag1;  // pending "set all to 0" / "set all to 1" for the children
    int tag;         // pending "flip all bits" for the children

    // Merge a left node a = [a.l, a.r] with the adjacent right node
    // b = [a.r + 1, b.r]. The result carries no pending lazy tags.
    friend SegTree operator+(const SegTree& a, const SegTree& b) {
        // Value-initialise: a plain `SegTree ret;` would leave tag0/tag1/tag
        // indeterminate, and callers copy the whole node (t[p] = t[ls] + t[rs]).
        SegTree ret{};
        const int lenA = a.r - a.l + 1;
        const int lenB = b.r - b.l + 1;
        ret.l = a.l;
        ret.r = b.r;
        // A run touching the shared border extends across it only when the
        // inner node consists entirely of that bit.
        ret.l0 = a.l0 + (a.l0 == lenA ? b.l0 : 0);
        ret.l1 = a.l1 + (a.l1 == lenA ? b.l1 : 0);
        ret.r0 = b.r0 + (b.r0 == lenB ? a.r0 : 0);
        ret.r1 = b.r1 + (b.r1 == lenB ? a.r1 : 0);
        ret.cnt1 = a.cnt1 + b.cnt1;
        // Best run is inside a, inside b, or straddles the border.
        ret.mx0 = max(max(a.mx0, b.mx0), a.r0 + b.l0);
        ret.mx1 = max(max(a.mx1, b.mx1), a.r1 + b.l1);
        return ret;
    }
} t[N << 2];  // heap-indexed node storage: root 1, children of p at 2p / 2p+1
// Heap-indexed children of node p. Parenthesised so the macros behave as a
// single value inside any surrounding expression (e.g. `ls + 1`, `rs * 2`).
#define ls (p << 1)
#define rs (p << 1 | 1)
// Build node p over a[l..r] (1-indexed). Leaves take the input bit directly;
// internal nodes are merged from their two children. Every node ends up with
// no pending lazy tags.
void build(int p, int l, int r) {
    t[p].l = l;
    t[p].r = r;
    if (l == r) {
        t[p].l0 = t[p].r0 = t[p].mx0 = !a[l];
        t[p].l1 = t[p].r1 = t[p].cnt1 = t[p].mx1 = a[l];
    } else {
        int mid = (l + r) >> 1;
        build(ls, l, mid);
        build(rs, mid + 1, r);
        t[p] = t[ls] + t[rs];  // also yields t[p].l == l and t[p].r == r
    }
    // Reset tags last: the merge above overwrites the whole node, tag fields
    // included, so resetting them before it (as the old code did) had no effect.
    t[p].tag0 = t[p].tag1 = t[p].tag = 0;
}
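/*
Node fields (struct SegTree):
  l, r        - index range [l, r] covered by the node
  l0, l1      - length of the run of 0s / 1s starting at l
  r0, r1      - length of the run of 0s / 1s ending at r
  cnt1        - number of 1s in [l, r]
  mx0, mx1    - longest run of 0s / 1s anywhere in [l, r]
Lazy tags (work still owed to the children; the node's own fields are already up to date):
  tag0, tag1  - pending "set all to 0" / "set all to 1"
  tag         - pending "flip all bits"
*/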
// Push node p's pending lazy tag down to its two children, then clear it.
// Assignment tags take precedence: when tag0 or tag1 is set, a `tag` stored
// alongside it is discarded. Anyone applying a flip on top of a pending
// assignment must therefore fold it in (set-0 <-> set-1) rather than toggle
// `tag`. The child updates below do exactly that.
// A leaf has no children, and its heap "children" 2p / 2p+1 may lie past the
// end of t[] for large n. So a leaf only drops its (meaningless) tags.
void spread(int p) {
    if (t[p].l == t[p].r) {
        t[p].tag0 = t[p].tag1 = t[p].tag = 0;
        return;
    }
    // Set every position of child c to bit v and record the assignment tag;
    // the assignment supersedes any flip pending on c.
    auto assign = [](int c, int v) {
        int len = t[c].r - t[c].l + 1;
        t[c].l0 = t[c].r0 = t[c].mx0 = v ? 0 : len;
        t[c].l1 = t[c].r1 = t[c].mx1 = t[c].cnt1 = v ? len : 0;
        t[c].tag0 = !v;
        t[c].tag1 = v;
        t[c].tag = 0;
    };
    // Flip every bit of child c and compose the flip with c's pending tag.
    auto flip = [](int c) {
        swap(t[c].l0, t[c].l1);
        swap(t[c].r0, t[c].r1);
        swap(t[c].mx0, t[c].mx1);
        t[c].cnt1 = t[c].r - t[c].l + 1 - t[c].cnt1;
        if (t[c].tag0) t[c].tag0 = 0, t[c].tag1 = 1;
        else if (t[c].tag1) t[c].tag0 = 1, t[c].tag1 = 0;
        else t[c].tag ^= 1;
    };
    if (t[p].tag0) {
        assign(ls, 0);
        assign(rs, 0);
    } else if (t[p].tag1) {
        assign(ls, 1);
        assign(rs, 1);
    } else if (t[p].tag) {
        flip(ls);
        flip(rs);
    }
    t[p].tag0 = t[p].tag1 = t[p].tag = 0;
}
// Apply operation `op` to positions [l, r] within node p's subtree:
//   op == 0: set all to 0;  op == 1: set all to 1;  otherwise: flip all bits.
// A fully covered node is updated in place and keeps a lazy tag for its
// children. The flip is composed with any assignment already pending on the
// node (pending set-0 + flip == pending set-1, and vice versa) instead of
// just toggling `tag`. spread() discards `tag` whenever it pushes tag0/tag1,
// so a flip parked next to an assignment tag would silently be lost.
// Toggling alone is only safe if the node's tags were cleared first (e.g. by
// calling spread(p) before the coverage check). With the composition below,
// the usual check-then-spread order is correct and leaves are never spread.
void mdf(int p, int l, int r, int op) {
    if (t[p].l >= l && t[p].r <= r) {
        int len = t[p].r - t[p].l + 1;
        if (op == 0) {
            t[p].l0 = t[p].r0 = t[p].mx0 = len;
            t[p].l1 = t[p].r1 = t[p].cnt1 = t[p].mx1 = 0;
            t[p].tag0 = 1; t[p].tag1 = t[p].tag = 0;
        } else if (op == 1) {
            t[p].l1 = t[p].r1 = t[p].mx1 = t[p].cnt1 = len;
            t[p].l0 = t[p].r0 = t[p].mx0 = 0;
            t[p].tag1 = 1; t[p].tag0 = t[p].tag = 0;
        } else {
            swap(t[p].l0, t[p].l1);
            swap(t[p].r0, t[p].r1);
            swap(t[p].mx0, t[p].mx1);
            t[p].cnt1 = len - t[p].cnt1;
            // Same composition spread() uses when it flips a child.
            if (t[p].tag0) t[p].tag0 = 0, t[p].tag1 = 1;
            else if (t[p].tag1) t[p].tag0 = 1, t[p].tag1 = 0;
            else t[p].tag ^= 1;
        }
        return;
    }
    // Partially covered: settle pending work before descending.
    spread(p);
    int mid = (t[p].l + t[p].r) >> 1;
    if (l <= mid) mdf(ls, l, r, op);
    if (r > mid) mdf(rs, l, r, op);
    t[p] = t[ls] + t[rs];
    // p has no pending work after spread(p); do not inherit whatever tag
    // values the merge produced.
    t[p].tag0 = t[p].tag1 = t[p].tag = 0;
}
// Count the 1s in the part of [l, r] that lies inside node p's range.
// A fully covered node's cnt1 is already current (lazy tags only describe
// work owed to its children), so it is returned before spreading. This also
// keeps spread() from ever being called on a leaf.
int qrycnt(int p, int l, int r) {
    if (t[p].l >= l && t[p].r <= r) return t[p].cnt1;
    spread(p);
    int mid = (t[p].l + t[p].r) >> 1;
    int total = 0;
    if (l <= mid) total += qrycnt(ls, l, r);
    if (r > mid) total += qrycnt(rs, l, r);
    return total;
}
// Return a merged node summarising the part of [l, r] inside node p's range
// (main() reads its mx1, the longest run of 1s). A fully covered node is
// already current, so it is returned before spreading, and leaves are never
// spread. Assumes l <= r and that [l, r] intersects p's range.
SegTree qrymx(int p, int l, int r) {
    if (t[p].l >= l && t[p].r <= r) return t[p];
    spread(p);
    int mid = (t[p].l + t[p].r) >> 1;
    if (r <= mid) return qrymx(ls, l, r);  // entirely in the left child
    if (l > mid) return qrymx(rs, l, r);   // entirely in the right child
    return qrymx(ls, l, r) + qrymx(rs, l, r);
}
// Reads n and m, then the n input values, then m lines "op l r" with
// 0-indexed inclusive bounds. Any op <= 2 modifies the range (0: set 0,
// 1: set 1, otherwise flip). Op 3 prints the number of 1s; any other op
// prints the longest run of consecutive 1s.
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    cin >> n >> m;
    for (int i = 1; i <= n; ++i) cin >> a[i];
    build(1, 1, n);

    int op, l, r;  // kept outside the loop: values persist across iterations
    for (int q = 0; q < m; ++q) {
        cin >> op >> l >> r;
        // Shift to the tree's 1-indexed positions.
        ++l;
        ++r;
        if (op > 2) {
            if (op == 3) cout << qrycnt(1, l, r) << '\n';
            else cout << qrymx(1, l, r).mx1 << '\n';
        } else {
            mdf(1, l, r, op);
        }
    }
    return 0;
}
在上面这份代码的 mdf、qrycnt、qrymx 函数中，我一开始写的顺序是：
if (t[p].l >= l && t[p].r <= r) ...
spread(p);
即先判断是否完全覆盖、再下传懒标记，结果只得了 20 pts。改成：
spread(p);
if (t[p].l >= l && t[p].r <= r) ...
即先下传懒标记、再判断是否完全覆盖，就 AC 了。可这两种写法看起来差不多，为什么结果不同？求助！