“如果两个数出现次数都 $\ge 2$ 就暴力,否则直接算”的做法可以在约 3s 内通过。这显然不合理:构造 $\frac{1}{2}N$ 个 1,以及 $2\sim\frac{1}{2}N$ 每个数各 2 个,然后把 1 和每个数都询问一遍,就能卡掉这种做法。
#include <bits/stdc++.h>
// rep(i, n): i runs over [0, n). n is evaluated once and cached in a variable
// whose name is token-pasted from i, so nested loops with different indices
// do not clash.
#define rep(i,n) for(int i=0,del##i##verme=int(n);i<del##i##verme;++i)
// rep1(i, n): i runs over [1, n]. n is evaluated once, as in rep.
#define rep1(i,n) for(int i=1,parano##i##a=int(n);i<=parano##i##a;++i)
#define pb push_back
#define mp make_pair
typedef long long ll;
using namespace std;
// Capacity of the per-position and per-value arrays. Indices 1..n are used,
// so this assumes n < N (NOTE(review): confirm against the problem limits).
const int N = 300005;
// n: number of positions; q: number of queries.
// a[i]: value at position i (1-indexed); b[i]: weight summed at position i.
// lst[p]: previous position holding the same value as position p (filled in
// main, but not read anywhere in this file).
int n, q, a[N], b[N], lst[N];
// Memoized answers for pairs swept by solve_brute, keyed by the normalized
// (x, y) pair that query builds.
map<pair<int, int>, ll> M;
// c[v]: number of occurrences of value v.
int c[N];
// occ[v]: positions holding value v, in increasing order (pushed while
// reading a[1..n]).
vector<int> occ[N];
// Scratch buffer: merged occurrence positions of the pair being swept.
vector<int> vc;
// Indexed by balance (#x - #y) offset by N, hence 2 * N entries.
// val[k] is the smallest prefix sum seen at balance k in the current sweep.
// It is only meaningful when tag[k] == tg; timestamping avoids clearing
// the arrays between sweeps.
ll val[2 * N];
int tag[2 * N];
// Timestamp of the current solve_brute sweep.
int tg;
// Exhaustive sweep for the value pair (x, y), in O(c[x] + c[y]).
//
// Walks the occurrence positions of x and y in increasing index order and
// tracks two running quantities:
//   balance = (#x seen - #y seen) + N
//   prefix  = sum of b over the positions seen so far
// Two prefixes with equal balance bound a stretch that holds equally many
// x's and y's. The best such stretch ending here is the current prefix
// minus the smallest earlier prefix with the same balance.
//
// Returns the maximum such sum, or -1145141919810 if no balanced stretch
// exists (e.g. one of the two values never occurs).
ll solve_brute(int x, int y)
{
    // Rebuild the merged position list. std::merge takes from occ[x] first
    // on ties, matching the original hand-written merge.
    vc.clear();
    std::merge(occ[x].begin(), occ[x].end(),
               occ[y].begin(), occ[y].end(),
               std::back_inserter(vc));

    // New timestamp: slots whose tag differs are stale leftovers from
    // earlier sweeps and count as unseen.
    ++tg;

    int balance = N;
    ll prefix = 0;
    ll best = -1145141919810;

    // The empty prefix has balance N and sum 0.
    tag[balance] = tg;
    val[balance] = prefix;

    for(size_t k = 0; k < vc.size(); ++k)
    {
        const int pos = vc[k];
        balance += (a[pos] == x) ? 1 : -1;
        prefix += b[pos];

        if(tag[balance] != tg)
        {
            // First time this balance appears in the sweep: just record it.
            tag[balance] = tg;
            val[balance] = prefix;
            continue;
        }

        best = max(best, prefix - val[balance]);
        if(prefix < val[balance]) val[balance] = prefix;
    }
    return best;
}
// Answers one query for the value pair (x, y).
//
// The pair is first normalized so that x is the value with fewer occurrences
// (ties broken toward the smaller value). This also makes the memo key
// independent of argument order.
//  - x occurs at least twice: run the O(c[x] + c[y]) sweep in solve_brute
//    once and memoize the result in M.
//    NOTE: many distinct pairs that share one frequent value can still make
//    the total work quadratic.
//  - x occurs exactly once, at position z: in the merged order, a stretch
//    with equally many x's and y's must be {z, an adjacent y}. So only the
//    nearest occurrences of y before and after z matter, found by binary
//    search in occ[y].
//  - x never occurs: no balanced stretch exists. Return the same sentinel
//    solve_brute yields in that case, instead of reading occ[x][0] out of
//    range. Presumably the input guarantees both values occur; this only
//    removes the undefined behavior if it does not.
//
// Sums of two b values are formed in ll: b[] stores int, and adding two
// large ints could overflow before widening.
ll query(int x, int y)
{
    const ll kNoAnswer = -1145141919810LL;  // same sentinel solve_brute uses

    if(c[x] > c[y] || (c[x] == c[y] && x > y)) swap(x, y);

    const pair<int, int> key = mp(x, y);
    const auto hit = M.find(key);  // single lookup instead of count + []
    if(hit != M.end()) return hit->second;

    if(c[x] >= 2) return M[key] = solve_brute(x, y);
    if(c[x] == 0) return kNoAnswer;

    // c[x] == 1: pair the lone x with its nearest y on either side.
    const int z = occ[x][0];
    const auto itr = lower_bound(occ[y].begin(), occ[y].end(), z);
    ll ans = kNoAnswer;
    if(itr != occ[y].begin()) ans = max(ans, ll(b[z]) + b[*prev(itr)]);
    if(itr != occ[y].end()) ans = max(ans, ll(b[z]) + b[*itr]);
    return ans;
}
// Reads n and q, then a[1..n] and b[1..n], then q queries (x, y).
// Prints query(x, y) for each query on its own line.
int main()
{
#ifndef DEBUG
    // Untied, unsynchronized iostreams for fast bulk I/O, unless DEBUG is defined.
    ios_base::sync_with_stdio(false);
    cin.tie(0);
#endif
    cin >> n >> q;
    rep1(i, n)
    {
        cin >> a[i];
        // i increases, so each occ[v] stays sorted. Assumes a[i] is a valid
        // index into occ (TODO confirm 0 <= a[i] < N from the constraints).
        occ[a[i]].pb(i);
    }
    rep1(i, n) cin >> b[i];
    rep1(i, n) c[i] = int(occ[i].size());
    // There is no pass filling lst[] here: no code in this file reads it.
    while(q--)
    {
        int x, y;
        cin >> x >> y;
        cout << query(x, y) << '\n';
    }
    return 0;
}