求助程序瓶颈玄关
  • 板块学术版
  • 楼主Rindong
  • 当前回复3
  • 已保存回复3
  • 发布时间2024/11/1 12:44
  • 上次更新2024/11/1 16:58:44
查看原帖
求助程序瓶颈玄关
967972
Rindong楼主2024/11/1 12:44

事情是这样的,昨天刚学莫队二次离线,在写 P5501 。然后写了份代码,我确定我的复杂度是正确的(代码放在最后)并且几乎用了能用的卡常手段,但是还是 TLE record,求助大佬分析程序瓶颈。

自己简单进行分析了一下,判断瓶颈出现在这里:

for (int i = 1; i <= n; i++) {
	//瓶颈不在 insert,注释过 insert函数和后面的遍历
	//发现是 后面那个遍历最慢。
	insert(p1, b1, arr[i], 1);
	insert(p2, b2, arr[i], arr[i]);
	for (int ind = head[i]; ind; ind = nex[ind]) {
		Offline it = offline[ind];
		int l = it.l, r = it.r, id = it.id;
		ll d = it.d;
		//我曾统计过这个循环的运行次数,是跟题解一样的,甚至比
		//题解的次数要少,而且 query 函数的复杂度是 O(1)
		//但是我如果把这两句 query 去掉,就运行的飞快,这是为什么呢?
		//难道是关于类型转换?关于cache命中?求dalao优化(已失去所有力气和手段)
		for (int x = l; x <= r; x++) {
			Q[id].ret -= query(p1, b1, 1, arr[x] - 1) * arr[x] * d;
			Q[id].ret -= query(p2, b2, arr[x] + 1, MAX_M) * d;
		}
	}
}
/* insert 与 query 的实现
void insert(ll _in[MAX_B][MAX_B], ll _out[MAX_B], int ind, ll val) {
	int bb = (ind - 1) / zylen + 1;
	for (int x = (ind - 1) % zylen + 1; x <= zylen; x++) _in[bb][x] += val;
	for (int b = bb; b <= bcnt; b++) _out[b] += val;
}
ll query(ll _in[MAX_B][MAX_B], ll _out[MAX_B], int l, int r) {
	ll ret = 0;
	if (l > r) return 0;
	int aa = (l - 1) / zylen + 1, bb = (r - 1) / zylen + 1;
	if (aa == bb) return _in[bb][(r - 1) % zylen + 1] - _in[bb][(l - 1) % zylen];
	ret += _out[bb - 1] - _out[aa];
	ret += _in[bb][(r - 1) % zylen + 1];
	ret += _in[aa][zylen] - _in[aa][(l - 1) % zylen];
	return ret;
}*/

整体代码

#include <iostream>
#include <cmath>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
#define MAX_N 500010
#define MAX_M 100000
#define MAX_B 333
#define ll long long
inline int read() {
	int ret = 0;
	char ch = getchar();
	while (ch < '0' || '9' < ch) ch = getchar();
	while ('0' <= ch && ch <= '9')
		ret = ret * 10 + ch - '0', ch = getchar();
	return ret;
}
void write(ll x) {
	if (x > 9) write(x / 10);
	putchar(x % 10 + '0');
}
int n, m, len;
int get_blo(int i) { return i / len; }
struct Question {
	int l, r, id;
	ll ret;
	bool operator< (const Question& other) const {
		int bl = get_blo(l), br = get_blo(other.l);
		if (bl != br) return bl < br;
		if (bl % 2) return r < other.r;
		return r > other.r;
	}
} Q[MAX_N];
struct Offline {
	int id, l, r;
	ll d;
};
int idle = 1;
int head[MAX_N] = { 0 }, nex[MAX_N * 2] = { 0 };
Offline offline[MAX_N * 2];
void add(int x, Offline off) {
	nex[idle] = head[x];
	head[x] = idle;
	offline[idle++] = off;
}
ll arr[MAX_N] = { 0 };
//整块前缀和-普通值域 & 总和值域
ll b1[MAX_B] = { 0 }, b2[MAX_B] = { 0 };
//块内前缀和-普通值域 & 总和值域
ll p1[MAX_B][MAX_B] = { 0 }, p2[MAX_B][MAX_B] = { 0 };
int bcnt, zylen;
void insert(ll _in[MAX_B][MAX_B], ll _out[MAX_B], int ind, ll val) {
	int bb = (ind - 1) / zylen + 1;
	for (int x = (ind - 1) % zylen + 1; x <= zylen; x++) _in[bb][x] += val;
	for (int b = bb; b <= bcnt; b++) _out[b] += val;
}
ll query(ll _in[MAX_B][MAX_B], ll _out[MAX_B], int l, int r) {
	ll ret = 0;
	if (l > r) return 0;
	int aa = (l - 1) / zylen + 1, bb = (r - 1) / zylen + 1;
	if (aa == bb) return _in[bb][(r - 1) % zylen + 1] - _in[bb][(l - 1) % zylen];
	ret += _out[bb - 1] - _out[aa];
	ret += _in[bb][(r - 1) % zylen + 1];
	ret += _in[aa][zylen] - _in[aa][(l - 1) % zylen];
	return ret;
}
ll _tr1[MAX_M + 10] = { 0 };
ll _tr2[MAX_M + 10] = { 0 };
int lowbit(int i) { return i & (-i); }
void tr_insert(ll tr[], int ind, ll val) {
	while (ind <= MAX_M) {
		tr[ind] += val;
		ind += lowbit(ind);
	}
}
ll tr_query(ll tr[], int ind) {
	ll ret = 0;
	while (ind) {
		ret += tr[ind];
		ind -= lowbit(ind);
	}
	return ret;
}
ll pres1[MAX_N] = { 0 }; //比自己小的数数量之和
ll pres2[MAX_N] = { 0 }; //比自己大的数之和
ll pres3[MAX_N] = { 0 }; //arr[i]前缀和
ll ans[MAX_N] = { 0 };
int main() {
	zylen = bcnt = 320;
	n = read(), m = read();
	for (int x = 1; x <= n; x++) arr[x] = read();
	for (int x = 1; x <= n; x++) {
		pres1[x] = pres1[x - 1] + tr_query(_tr1, arr[x] - 1) * arr[x];
		pres2[x] = pres2[x - 1] + (tr_query(_tr2, MAX_M) - tr_query(_tr2, arr[x]));
		pres3[x] = pres3[x - 1] + arr[x];
		tr_insert(_tr1, arr[x], 1);
		tr_insert(_tr2, arr[x], arr[x]);
	}
	len = sqrt(n);
	for (int x = 1; x <= m; x++) {
		int l = read(), r = read();
		Q[x] = { l, r, x, 0 };
	}
	sort(Q + 1, Q + 1 + m);
	for (int x = 1, l = 1, r = 0; x <= m; x++) {
		int L = Q[x].l, R = Q[x].r;
		if (r < R) add(l - 1, { x, r + 1, R, 1 });
		if (r > R) add(l - 1, { x, R + 1, r, -1 });
		r = R;
		if (l > L) add(r, { x, L, l - 1, -1 });
		if (l < L) add(r, { x, l, L - 1, 1 });
		l = L;
	}
	for (int i = 1; i <= n; i++) {
		insert(p1, b1, arr[i], 1);
		insert(p2, b2, arr[i], arr[i]);
		//用 vector也试过了一样TLE
		for (int ind = head[i]; ind; ind = nex[ind]) {
			Offline it = offline[ind];
			int l = it.l, r = it.r, id = it.id;
			ll d = it.d;
			for (int x = l; x <= r; x++) {
				Q[id].ret -= query(p1, b1, 1, arr[x] - 1) * arr[x] * d;
				Q[id].ret -= query(p2, b2, arr[x] + 1, MAX_M) * d;
			}
		}
	}
	ll res = 0;
	for (int x = 1, l = 1, r = 0; x <= m; x++) {
		int L = Q[x].l, R = Q[x].r, id = Q[x].id;
		if (r < R) {
			res = res + (pres1[R] - pres1[r]) + (pres2[R] - pres2[r])
				+ (pres3[R] - pres3[r]);
		}
		if (r > R) {
			res = res - ((pres1[r] - pres1[R]) + (pres2[r] - pres2[R]) +
				(pres3[r] - pres3[R]));
		}
		if (l < L) {
			res -= pres3[L - 1] - pres3[l - 1];
			res = res + ((pres1[L - 1] - pres1[l - 1]) + (pres2[L - 1] - pres2[l - 1]));
		}
		if (l > L) {
			res += pres3[l - 1] - pres3[L - 1];
			res = res - ((pres1[l - 1] - pres1[L - 1]) + (pres2[l - 1] - pres2[L - 1]));
		}
		l = L, r = R;
		res += Q[x].ret;
		ans[id] = res;
	}
	//for (int x = 1; x <= m; x++) printf("%lld\n", ans[x]);
	for (int x = 1; x <= m; x++) write(ans[x]), putchar('\n');
	return 0;
}
2024/11/1 12:44
加载中...