RT。先不用管TLE,先把WA解决。 目前过不了样例。
#include <bits/stdc++.h>
using namespace std;
#define LL long long
LL Pow(LL a, LL b, LL p) {
LL ans = 1;
a = (a % p + p) % p;
for (; b; b >>= 1) {
if (b & 1) ans = (1ull * a * ans) % p;
a = (1ull * a * a) % p;
}
return ans;
}
LL calc(LL n, LL x, LL P) {
if (!n) return 1;
LL s = 1;
for (LL i = 1; i <= P; i++)
if (i % x) s = s * i % P;
s = Pow(s, n / P, P);
for (LL i = n / P * P + 1; i <= n; i++)
if (i % x) s = i % P * s % P;
return s * calc(n / x, x, P) % P;
}
LL inverse(LL x, LL y) {
return Pow(x, y - 2, y);
}
LL multilucas(LL m, LL n, LL x, LL P) {
LL cnt = 0;
for (LL i = m; i; i /= x) cnt += i / x;
for (LL i = n; i; i /= x) cnt -= i / x;
for (LL i = m - n; i; i /= x) cnt -= i / x;
return Pow(x, cnt, P) % P * calc(m, x, P) % P * inverse(calc(n, x, P), P) % P * inverse(calc(m - n, x, P), P) % P;
}
void exgcd(LL a, LL b, LL& x, LL& y) {
if (b == 0) {
x = 1, y = 0;
return;
}
exgcd(b, a % b, y, x);
y -= a / b * x;
}
LL CRT(LL k, LL* a, LL* r) {
LL n = 1, ans = 0;
for (LL i = 1; i <= k; i++) n = n * r[i];
for (LL i = 1; i <= k; i++) {
LL m = n / r[i], b, y;
exgcd(m, r[i], b, y);
ans = (ans + a[i] * m * b % n) % n;
}
return (ans % n + n) % n;
}
LL exlucas(LL m, LL n, LL P) {
LL cnt = 0;
LL p[30], a[30];
for (LL i = 2; i * i <= P; i++) {
if (P % i == 0) {
p[++cnt] = 1;
while (P % i == 0) p[cnt] = p[cnt] * i, P /= i;
a[cnt] = multilucas(m, n, i, p[cnt]);
}
}
if (P > 1) p[++cnt] = P, a[cnt] = multilucas(m, n, P, P);
return CRT(cnt, a, p);
}
// These above are templates
LL T, p, n, n1, n2, m, f[100000], a[100000];
LL work(int x) {
LL cnt = 0;
for (int i = 1; i <= n1 + n2; i++) {
cnt += (i > n1) ? (a[i] - 1) : ((x & (1 << (i - 1))) ? a[i] : 0);
}
return exlucas(m - cnt, n - 1, p);
}
int cnt_bit[20];
int main() {
// init begin
cnt_bit[0] = 0;
for (int i = 1; i <= 19; i++) cnt_bit[i] = cnt_bit[i - (i & -i)] + 1;
// init end
scanf("%lld%lld", &T, &p);
while (T--) {
memset(f, 0, sizeof(f));
scanf("%lld%lld%lld%lld", &n, &n1, &n2, &m);
for (int i = 1; i <= n1 + n2; i++) scanf("%lld", &a[i]);
for (int i = 0; i < (1 << n1); i++) {
f[cnt_bit[i]] += work(i);
}
LL ans = 0; int flag = 1;
for (int i = 0; i <= n1; i++) {
ans = ans + flag * f[i];
flag = (flag == 1) ? -1 : 1;
}
printf("%lld\n", ans);
}
return 0;
}