事情是这样的：昨天刚学了莫队二次离线，正在写 P5501。写完代码后，我确定复杂度是正确的（完整代码放在最后），并且几乎用上了所有能用的卡常手段，但还是 TLE（见评测记录），求助大佬分析程序瓶颈。
我自己简单分析了一下，判断瓶颈出现在这里：
// Offline sweep of the secondary-offline Mo, as posted: for each prefix end i,
// insert arr[i] into the value-block tables, then evaluate every task attached
// to i over its position range [l, r].
for (int i = 1; i <= n; i++) {
// The bottleneck is not insert(): I tried commenting out insert() and the loop
// below separately, and found that the loop below is the slowest part.
insert(p1, b1, arr[i], 1);
insert(p2, b2, arr[i], arr[i]);
for (int ind = head[i]; ind; ind = nex[ind]) {
Offline it = offline[ind];
int l = it.l, r = it.r, id = it.id;
ll d = it.d;
// I counted how many times this loop runs: it matches the editorial solution
// (even slightly fewer iterations), and query() is O(1).
// But if I remove these two query() lines it runs very fast. Why?
// Is it type conversion? Cache hits? Please help optimize (out of ideas).
// NOTE(review): every query() call evaluates `/` and `%` by the global
// variable zylen (see the implementation below); since it is not a
// compile-time constant this likely compiles to hardware integer division,
// executed on every iteration of this innermost loop.
for (int x = l; x <= r; x++) {
Q[id].ret -= query(p1, b1, 1, arr[x] - 1) * arr[x] * d;
Q[id].ret -= query(p2, b2, arr[x] + 1, MAX_M) * d;
}
}
}
/* Implementation of insert and query used by the loop above
void insert(ll _in[MAX_B][MAX_B], ll _out[MAX_B], int ind, ll val) {
int bb = (ind - 1) / zylen + 1;
for (int x = (ind - 1) % zylen + 1; x <= zylen; x++) _in[bb][x] += val;
for (int b = bb; b <= bcnt; b++) _out[b] += val;
}
ll query(ll _in[MAX_B][MAX_B], ll _out[MAX_B], int l, int r) {
ll ret = 0;
if (l > r) return 0;
int aa = (l - 1) / zylen + 1, bb = (r - 1) / zylen + 1;
if (aa == bb) return _in[bb][(r - 1) % zylen + 1] - _in[bb][(l - 1) % zylen];
ret += _out[bb - 1] - _out[aa];
ret += _in[bb][(r - 1) % zylen + 1];
ret += _in[aa][zylen] - _in[aa][(l - 1) % zylen];
return ret;
}*/
完整代码如下：
#include <iostream>
#include <cmath>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
// Problem limits. MAX_N sizes the per-position and per-query tables, MAX_M is
// the largest value handled by the value-domain structures, and MAX_B bounds
// both dimensions of the value-block tables.
constexpr int MAX_N = 500010;
constexpr int MAX_M = 100000;
constexpr int MAX_B = 333;
typedef long long ll;
// Reads the next run of decimal digits from stdin as a non-negative int.
// Every non-digit character before it (whitespace, a '-' sign, ...) is skipped,
// so a negative number is read as its absolute value. If EOF is reached before
// any digit, returns 0 instead of spinning forever: `ch` is an int so EOF stays
// distinguishable from every byte value.
inline int read() {
    int ret = 0;
    int ch = getchar();
    while (ch != EOF && (ch < '0' || '9' < ch)) ch = getchar();
    while ('0' <= ch && ch <= '9') {
        ret = ret * 10 + (ch - '0');
        ch = getchar();
    }
    return ret;
}
// Prints x in decimal to stdout with no separator. Intended for non-negative
// values; for a negative x only the single character (x % 10 + '0') is emitted.
void write(ll x) {
    char digits[24];
    int cnt = 0;
    do {
        digits[cnt++] = (char)(x % 10 + '0');
        x /= 10;
    } while (x > 0);
    while (cnt > 0) putchar(digits[--cnt]);
}
// n: array length, m: number of queries, len: Mo block size (set in main()).
int n, m, len;
// Index of the Mo block that position `pos` falls into; len must be positive.
int get_blo(int pos) {
    const int block = pos / len;
    return block;
}
struct Question {
int l, r, id;
ll ret;
bool operator< (const Question& other) const {
int bl = get_blo(l), br = get_blo(other.l);
if (bl != br) return bl < br;
if (bl % 2) return r < other.r;
return r > other.r;
}
} Q[MAX_N];
// An offline task: evaluate positions [l, r] against the value tables built
// from the prefix the task is attached to; the total, multiplied by d (+1 or -1),
// is subtracted from Q[id].ret.
struct Offline {
    int id;
    int l;
    int r;
    ll d;
};
// Per-position singly linked task lists: head[k] is the most recently added
// task for position k, nex[] chains to the next one, offline[] holds the task.
// Slot 0 terminates a list, so slots are handed out starting from 1.
int idle = 1;
int head[MAX_N] = {}, nex[MAX_N * 2] = {};
Offline offline[MAX_N * 2];
// Pushes `off` onto the front of position x's task list.
void add(int x, Offline off) {
    const int slot = idle++;
    offline[slot] = off;
    nex[slot] = head[x];
    head[x] = slot;
}
// Input array, 1-indexed (filled by main()).
ll arr[MAX_N] = {};
// Whole-block prefix sums over the value domain: b1 for counts, b2 for value sums.
ll b1[MAX_B] = {};
ll b2[MAX_B] = {};
// In-block prefix sums over the value domain: p1 for counts, p2 for value sums.
ll p1[MAX_B][MAX_B] = {};
ll p2[MAX_B][MAX_B] = {};
int bcnt, zylen;
void insert(ll _in[MAX_B][MAX_B], ll _out[MAX_B], int ind, ll val) {
int bb = (ind - 1) / zylen + 1;
for (int x = (ind - 1) % zylen + 1; x <= zylen; x++) _in[bb][x] += val;
for (int b = bb; b <= bcnt; b++) _out[b] += val;
}
ll query(ll _in[MAX_B][MAX_B], ll _out[MAX_B], int l, int r) {
ll ret = 0;
if (l > r) return 0;
int aa = (l - 1) / zylen + 1, bb = (r - 1) / zylen + 1;
if (aa == bb) return _in[bb][(r - 1) % zylen + 1] - _in[bb][(l - 1) % zylen];
ret += _out[bb - 1] - _out[aa];
ret += _in[bb][(r - 1) % zylen + 1];
ret += _in[aa][zylen] - _in[aa][(l - 1) % zylen];
return ret;
}
// Fenwick (binary indexed) trees over values 1..MAX_M. main() uses _tr1 to count
// inserted values and _tr2 to sum them.
ll _tr1[MAX_M + 10] = { 0 };
ll _tr2[MAX_M + 10] = { 0 };
// Lowest set bit of i (for i > 0).
int lowbit(int i) { return i & (-i); }
// Adds val at position ind. Positions outside [1, MAX_M] are ignored; without
// the guard, ind == 0 would loop forever because lowbit(0) == 0.
void tr_insert(ll tr[], int ind, ll val) {
    if (ind <= 0) return;
    for (; ind <= MAX_M; ind += lowbit(ind)) tr[ind] += val;
}
// Returns the sum over positions 1..ind. ind <= 0 gives 0. ind above MAX_M is
// clamped, since nothing can be stored past MAX_M, which also avoids reading
// beyond the array.
ll tr_query(ll tr[], int ind) {
    if (ind > MAX_M) ind = MAX_M;
    ll ret = 0;
    for (; ind > 0; ind -= lowbit(ind)) ret += tr[ind];
    return ret;
}
// pres1[x]: sum over y <= x of arr[y] * (number of earlier positions holding a smaller value).
ll pres1[MAX_N] = {};
// pres2[x]: sum over y <= x of the sum of earlier values larger than arr[y].
ll pres2[MAX_N] = {};
// pres3[x]: plain prefix sum of arr.
ll pres3[MAX_N] = {};
// ans[i]: answer to the i-th query in input order.
ll ans[MAX_N] = {};
int main() {
zylen = bcnt = 320;
n = read(), m = read();
for (int x = 1; x <= n; x++) arr[x] = read();
for (int x = 1; x <= n; x++) {
pres1[x] = pres1[x - 1] + tr_query(_tr1, arr[x] - 1) * arr[x];
pres2[x] = pres2[x - 1] + (tr_query(_tr2, MAX_M) - tr_query(_tr2, arr[x]));
pres3[x] = pres3[x - 1] + arr[x];
tr_insert(_tr1, arr[x], 1);
tr_insert(_tr2, arr[x], arr[x]);
}
len = sqrt(n);
for (int x = 1; x <= m; x++) {
int l = read(), r = read();
Q[x] = { l, r, x, 0 };
}
sort(Q + 1, Q + 1 + m);
for (int x = 1, l = 1, r = 0; x <= m; x++) {
int L = Q[x].l, R = Q[x].r;
if (r < R) add(l - 1, { x, r + 1, R, 1 });
if (r > R) add(l - 1, { x, R + 1, r, -1 });
r = R;
if (l > L) add(r, { x, L, l - 1, -1 });
if (l < L) add(r, { x, l, L - 1, 1 });
l = L;
}
for (int i = 1; i <= n; i++) {
insert(p1, b1, arr[i], 1);
insert(p2, b2, arr[i], arr[i]);
//用 vector也试过了一样TLE
for (int ind = head[i]; ind; ind = nex[ind]) {
Offline it = offline[ind];
int l = it.l, r = it.r, id = it.id;
ll d = it.d;
for (int x = l; x <= r; x++) {
Q[id].ret -= query(p1, b1, 1, arr[x] - 1) * arr[x] * d;
Q[id].ret -= query(p2, b2, arr[x] + 1, MAX_M) * d;
}
}
}
ll res = 0;
for (int x = 1, l = 1, r = 0; x <= m; x++) {
int L = Q[x].l, R = Q[x].r, id = Q[x].id;
if (r < R) {
res = res + (pres1[R] - pres1[r]) + (pres2[R] - pres2[r])
+ (pres3[R] - pres3[r]);
}
if (r > R) {
res = res - ((pres1[r] - pres1[R]) + (pres2[r] - pres2[R]) +
(pres3[r] - pres3[R]));
}
if (l < L) {
res -= pres3[L - 1] - pres3[l - 1];
res = res + ((pres1[L - 1] - pres1[l - 1]) + (pres2[L - 1] - pres2[l - 1]));
}
if (l > L) {
res += pres3[l - 1] - pres3[L - 1];
res = res - ((pres1[l - 1] - pres1[L - 1]) + (pres2[l - 1] - pres2[L - 1]));
}
l = L, r = R;
res += Q[x].ret;
ans[id] = res;
}
//for (int x = 1; x <= m; x++) printf("%lld\n", ans[x]);
for (int x = 1; x <= m; x++) write(ans[x]), putchar('\n');
return 0;
}