样例第 3 个 testcase wa 了。
我的想法是 f(S)2 拆成 f(S)⋅f(S)。
f(s) 以二进制形式表示,每一位上的数依次是 b1,b2,⋯,b10。
最后乘积为 i=1∑10j=1∑10bibj。
而 E(f(S)2) 的求解,这样分类讨论,设当前考虑 x 位与 y 位相乘的期望:
#include <bits/stdc++.h>
using namespace std;
#define FOR(i,a,b) for(int i=a;i<=b;i++)
#define REP(i,a,b) for(int i=a;i>=b;i--)
#define pb push_back
#define mkpr make_pair
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
typedef vector<int> vi;
template<class T>
void ckmx(T& a,T b){
a=max(a,b);
}
template<class T>
void ckmn(T& a,T b){
a=min(a,b);
}
template<class T>
T gcd(T a,T b){
return !b?a:gcd(b,a%b);
}
template<class T>
T lcm(T a,T b){
return a/gcd(a,b)*b;
}
#define gc getchar()
#define eb emplace_back
#define pc putchar
#define ep empty()
#define fi first
#define se second
#define pln pc('\n');
template<class T>
void wrint(T x){
if(x<0){
x=-x;
pc('-');
}
if(x>=10){
wrint(x/10);
}
pc(x%10^48);
}
template<class T>
void wrintln(T x){
wrint(x);
pln
}
template<class T>
void read(T& x){
x=0;
int f=1;
char ch=gc;
while(!isdigit(ch)){
if(ch=='-')f=-1;
ch=gc;
}
while(isdigit(ch)){
x=(x<<1)+(x<<3)+(ch^48);
ch=gc;
}
x*=f;
}
const ll mod=1e9+7;
ll qpow(ll a,ll b){
ll ret=1;
for(;b;a=a*a%mod,b>>=1){
if(b&1)ret=ret*a%mod;
}
return ret;
}
ll inv(ll a){
return qpow(a,mod-2);
}
const int maxn=2e5+5;
int n;
int a[maxn];
ll _p[maxn],p[maxn];
ll dp[maxn][2][2];
ll dp2[maxn][2];
void solve(int id_of_test){
auto at = [&](int s,int i)->bool{
return (s&(1<<i));
};
read(n);
FOR(i,1,n){
read(a[i]);
}
ll inv14=inv(10000);
FOR(i,1,n){
read(_p[i]);
p[i]=_p[i]*inv14%mod;
}
ll ans=0;
FOR(bit,0,9){
FOR(bit2,0,9){
if(bit==bit2){
dp2[0][0]=1;
FOR(i,1,n){
if(!at(a[i],bit)){
dp2[i][0]=dp2[i-1][0];
dp2[i][1]=dp2[i-1][1];
continue;
}
dp2[i][0]=dp2[i-1][0]*(1ll-p[i])%mod;
dp2[i][1]=dp2[i-1][1]*(1ll-p[i])%mod;
dp2[i][0]+=dp2[i-1][1]*p[i]%mod;
dp2[i][1]+=dp2[i-1][0]*p[i]%mod;
dp2[i][0]%=mod;
dp2[i][1]%=mod;
}
ans+=(1ll<<bit)*(1ll<<bit)%mod*dp2[n][1]%mod;
ans%=mod;
continue;
}
dp[0][0][0]=1;
FOR(i,1,n){
FOR(k1,0,1){
FOR(k2,0,1){
ll b1notchg=1,b2notchg=1;
if(at(a[i],bit))b1notchg=(1ll-p[i]);
if(at(a[i],bit2))b2notchg=(1ll-p[i]);
dp[i][k1][k2]=dp[i-1][k1][k2]*b1notchg%mod*b2notchg%mod;
dp[i][k1][k2]+=dp[i-1][k1^1][k2]*(1ll-b1notchg)%mod*b2notchg%mod;
dp[i][k1][k2]+=dp[i-1][k1][k2^1]*b1notchg%mod*(1ll-b2notchg)%mod;
dp[i][k1][k2]+=dp[i-1][k1^1][k2^1]*(1ll-b1notchg)%mod*(1ll-b2notchg)%mod;
dp[i][k1][k2]%=mod;
}
}
}
ans+=dp[n][1][1]*(1ll<<bit)%mod*(1ll<<bit2)%mod;
ans%=mod;
}
}
printf("%lld\n",(ans+mod)%mod);
}
int main()
{
int T;
read(T);
FOR(_,1,T){
solve(_);
}
return 0;
}
求大佬们帮助。