代码如下:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
ll m,n;
ll a[1000010];
ll v[1000010];
ll b[1000010],tot;
ll s[1000010];
ll read(){
register ll x=0;register char ch=getchar();
while(ch<'0'||'9'<ch) ch=getchar();
while('0'<=ch&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return x;
}
ll bj(ll x,ll y){
ll ret = tot + 1;
ll l = 1,r = tot;
while(l <= r){
ll mid = (l + r)>>1;
if(x / a[mid] <= y){
ret = mid;
r = mid - 1;
}
else{
l = mid + 1;
}
}
return ret;
}
int main(){
n = read();
m = read();
for(ll i = 1;i <= n;i++){
a[i] = read();
}
sort(a + 1,a + n + 1);
for(ll i = 1;i <= n;i++){
if(a[i - 1] != a[i]){
b[++tot] = a[i];
}
++v[tot];
}
unique(a + 1,a + n + 1);
for(ll i = 1;i <= tot;i++){
s[i] += s[i - 1] + v[i];
}
ll ans = 0;
for(ll i = 1;i <= tot;i++){
ll x = b[i];
if(x > m){
break;
}
ll ai = m / x;
ll mx = bj(ai,0) - 1;
for(ll j = 1;j * j <= tot;j++){
ll mn = bj(ai,j);
ans += (j * v[i] % mod) * (s[mx] - s[mn - 1]) % mod;
ans %= mod;
mx = mn - 1;
if(mx < 1){
break;
}
}
for(ll j = mx;j >= 1;j--){
ll y = b[j];
if(x * y > m){
break;
}
ans += m / x / y * v[i] % mod * v[j] % mod;
ans %= mod;
}
}
printf("%lld",ans);
return 0;
}
思路方面基本延循what_can_I_do大佬的题解,本人并未测试过该题解是否可行。
另据该篇讨论 ,使用upper_bound要比使用手写二分快,是否真实不知。但本人并未写成upper_bound代码(百度上查到的好像必须使用vector)。