rt
一直WA on #25,昨天晚上到现在都没调出来。输出是这样的: expected 623989857, found -425284871
我觉得负数可能是精度的锅,但我 k 开了long long,字典树也开了35层,甚至 #define int long long了还是一样的WA。
求神仙们看一下到底是哪里出了锅/kk
但是我感觉这种题发求助是没有人会看的???
@Limit 神仙如果您对这个求助不感兴趣可以忽略这个艾特/kel
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<climits>
#define int long long
using namespace std;
const int MAXN=5e4+10,MAXALPHA=5e4*35+10,MODER=1e9+7;
int n,a[MAXN],bt[MAXN][35],mk[35],sum[MAXN][35]; //sum(i,j) 1..i 二进制第j位为1的数的个数
int sz,ch[MAXALPHA][2],cnt[MAXALPHA];
long long ans,k;
inline int mknode(){return (++sz);}
long long power(int n){
if(!n)return 1;
if(n==1)return 2;
long long ans=power(n/2);
ans=(ans*ans)%MODER;
if(n&1){
ans=(ans*2)%MODER;
}
return ans;
}
inline void insert_value(int idx){
int u=0;cnt[0]++;
for(int i=1;i<=32;i++){
int c=bt[idx][i];
if(!ch[u][c])ch[u][c]=mknode();
u=ch[u][c];
cnt[u]++;
}
}
int dfs1(int u,int idx,int r){
if(r==33){return cnt[u];}
int now = mk[r],c=bt[idx][r]^1;
int ans = 0;
if(!now){
//这一位是1或0都可以
if(ch[u][c]){
//如果是1,下面的恒成立
ans += cnt[ch[u][c]];
}
if(ch[u][c^1]){
ans += dfs1(ch[u][c^1],idx,r+1);
}
}else{
//必须1
if(ch[u][c]){
ans += dfs1(ch[u][c],idx,r+1);
}
}
return ans;
}
long long dfs2(int u,int idx,int r,long long& v){
if(r==33){return (((long long)a[idx]^v)*cnt[u] )%MODER;}
int c=bt[idx][r]^1,now=mk[r];
long long ans=0;
if(now){
if(ch[u][c]){
v=(v<<1)|c;
ans+=dfs2(ch[u][c],idx,r+1,v);
}
}else{
if(ch[u][c]){
int son=ch[u][c];
if(cnt[son]){
long long tmp = (v<<1)|c;
tmp<<=(32-r);
int l=0,r=0;
//a中二分tmp
l = lower_bound(a+1,a+1+n,tmp)-a;
r = l+cnt[son]-1;
//查询u xor a[l..r] 的答案
for(int j=0;j<32;j++){
long long tot=sum[r][j]-sum[l-1][j];
if(bt[idx][32-j])tot=(r-l+1)-tot;
ans = (ans + (1LL<<j)*tot%MODER)%MODER;
}
}
}
if(ch[u][c^1]){
v=(v<<1)|(c^1);
ans+=dfs2(ch[u][c^1],idx,r+1,v);
ans%=MODER;
}
}
return ans%MODER;
}
bool check(int mid){
long long cnt = 0;
int v=mid;
for(int i=32;i>=1;i--){
mk[i]=v&1;
v>>=1;
}
for(int i=1;i<=n;i++){
//查询xor(a[i],a[j]) >= mid的个数
cnt+=dfs1(0,i,1);
}
return (cnt>=k+k);
}
signed main(){
scanf("%lld%lld",&n,&k);
for(int i=1;i<=n;i++){
scanf("%lld",&a[i]);
}
sort(a+1,a+1+n);
for(int i=1;i<=n;i++){
int v=a[i];
for(int j=32;j>=1;j--){
bt[i][j]=v&1;
v>>=1;
}
for(int j=0;j<32;j++){
sum[i][j]=sum[i-1][j]+bt[i][32-j];
}
insert_value(i);
}
//O(n log^2 w)求第k大点对值
int l = 0,r = 1e9,kv = 0;
while(l<=r){
int mid=(l+r)>>1;
if(check(mid)){
kv=mid;
l=mid+1;
}else{
r=mid-1;
}
}
//O(n log w)求大于k的点对值
int v=kv+1;
for(int i=32;i>=1;i--){
mk[i]=v&1;v>>=1;
}
long long tot = 0;
for(int i=1;i<=n;i++){
tot += dfs1(0,i,1);
long long v=0;
ans = (ans + dfs2(0,i,1,v))%MODER;
}
//=kv的部分单独算
ans += ((k+k-tot)%MODER * kv)%MODER;
ans %= MODER;
ans = (ans*power(MODER-2))%MODER;
printf("%lld",ans);
return 0;
}