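// 参考这篇题解写的 — Written with reference to a solution write-up (题解); the referenced write-up is not identified in this file.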
#include <bits/stdc++.h>
// Competitive-programming type shortcuts.
// `int` is redefined to `long long` for the rest of the file, so every `int`
// written below (parameters, locals, return types) is actually 64-bit; this is
// why the entry point is declared as `signed main()`.
#define int long long
// 128-bit signed integer, used for values and intermediate products that may
// not fit in 64 bits.
#define int1 __int128
// Unsigned 64-bit shorthand (not referenced in the visible code).
#define uint unsigned long long
// Lowest set bit of x, e.g. lowbit(12) == 4.
// The argument is fully parenthesised so that expressions such as lowbit(i + j)
// expand to ((i + j) & -(i + j)) instead of (i + j & -i + j).
#define lowbit(x) ((x) & -(x))
using namespace std;
namespace fastio
{
// Size in bytes of the input buffer that each fread call refills.
const int bufl = 1 << 20;

// Buffered integer reader over a FILE* (stdin by default).
// Input is pulled in bufl-sized chunks; parsing stops cleanly at end of input.
struct IN
{
    FILE *IT = stdin;                        // stream being read
    char ibuf[bufl], *is = ibuf, *it = ibuf; // buffer; [is, it) holds the unread bytes

    // Returns the next input byte as a non-negative value, or EOF once the
    // stream is exhausted. Returning int (not char) keeps EOF distinct from
    // byte 0xFF and works on platforms where plain char is unsigned.
    inline int getChar()
    {
        if (is == it)
        {
            it = (is = ibuf) + fread(ibuf, 1, bufl, IT);
            if (is == it)
                return EOF;
        }
        return static_cast<unsigned char>(*is++);
    }

    // Parses the next decimal integer into a.
    // Non-digit characters before the number are skipped, and every '-' among
    // them toggles the sign. If the input ends before any digit is found, a is
    // left as 0 and the call returns instead of spinning on EOF.
    IN &operator>>(int &a)
    {
        a = 0;
        int negative = 0;
        int c = getChar();
        while (c != EOF && (c < '0' || c > '9'))
        {
            negative ^= (c == '-');
            c = getChar();
        }
        while (c >= '0' && c <= '9')
        {
            a = a * 10 + (c - '0');
            c = getChar();
        }
        if (negative)
            a = -a;
        return *this;
    }
} fin;
}
// #define cin fastio::fin
// Eight neighbour offsets {dx, dy}: the four axis directions first,
// followed by the four diagonals.
const int DM[8][2] = {
    {0, 1}, {1, 0}, {0, -1}, {-1, 0},
    {1, 1}, {1, -1}, {-1, 1}, {-1, -1},
};

// The constants below are not referenced anywhere in this file.
// NOTE(review): they look like template leftovers (a hash modulus/base and
// array bounds) — confirm their intended meaning before relying on them.
const int HM = 5644863343;
const int HB = 131;
const int HI = 100003;
const int N = 100001;
// Prints a signed 128-bit integer in decimal to stdout, without a trailing newline.
// (printf/iostream have no __int128 support, hence the manual conversion.)
// Negative values are printed with a leading '-'. The magnitude is taken in
// unsigned form so that even the most negative value, whose signed negation
// would overflow, prints correctly.
void write(int1 x)
{
    unsigned __int128 mag = x < 0 ? -static_cast<unsigned __int128>(x)
                                  : static_cast<unsigned __int128>(x);
    if (x < 0)
        putchar('-');
    char digits[40]; // 2^128 has 39 decimal digits
    int len = 0;
    do
    {
        digits[len++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    } while (mag != 0);
    while (len > 0)
        putchar(digits[--len]);
}
void exgcd(int a, int b, int1 &x, int1 &y)
{
if (!b)
{
x = 1;
y = 0;
return;
}
exgcd(b, a % b, x, y);
int1 t = x;
x = y;
y = t - (a / b) * y;
return;
}
// Modular inverse of a modulo p via the extended Euclidean algorithm,
// normalised into [0, p).
// Precondition: gcd(a, p) == 1 — this is not checked, and otherwise the
// returned value is not an inverse.
int1 inv(int a, int p)
{
    int1 coefA, coefP;
    exgcd(a, p, coefA, coefP);
    const int1 r = coefA % p;
    return r < 0 ? r + p : r;
}
// Exact power x^k by recursive squaring, computed in __int128 with no modulus.
// Preconditions: k >= 0, and x^k (and every intermediate square) must fit in
// __int128 — nothing is reduced, so large inputs overflow.
int1 qpow(int x, int k)
{
    if (!k)
        return 1;
    const int1 half = qpow(x, k >> 1);
    const int1 square = half * half;
    return (k & 1) ? square * x : square;
}
// Legendre's formula: returns the sum over i >= 1 of floor(a / p^i), which for
// a prime p and a >= 0 is the exponent of p in a!.
// Requires p >= 2 (with p == 1 the loop never terminates for a != 0).
int1 niu(int a, int p)
{
    int1 total = 0;
    for (a /= p; a != 0; a /= p)
        total += a;
    return total;
}
// Returns (a! with every factor p removed) mod pc, where pc is expected to be
// a power of the prime p (as supplied by exlucas).
// Uses a!  =  (+/-1)^(a / pc) * prefix[a % pc] * (a / p)!  (mod pc, p-factors removed),
// applied iteratively until the argument drops to 0 or 1.
int1 wilson(int a, int p, int pc)
{
    // prefix[i] = product of all j in [1, i] that are not divisible by p, mod pc.
    vector<int1> prefix(pc + 5);
    prefix[0] = 1;
    for (int i = 1; i < pc; i++)
        prefix[i] = (i % p == 0) ? prefix[i - 1] : prefix[i - 1] * i % pc;

    // The product of the units over one full period [1, pc] is -1 (mod pc),
    // except for pc = 2^k with k >= 3, where it is +1.
    const bool periodNegates = !(p == 2 && pc > 4);

    int1 result = 1;
    for (; a > 1; a /= p)
    {
        // a / pc complete periods contribute (-1)^(a / pc): flip only when odd.
        if (periodNegates && (a / pc) % 2 != 0)
            result = pc - result;
        result = result * prefix[a % pc] % pc;
    }
    return result;
}
// Chinese Remainder Theorem for pairwise-coprime moduli p[0..].
// Returns the unique x in [0, prod(p)) with x = a[i] (mod p[i]) for every i.
// An empty input yields 0, the unique residue modulo 1.
// Residues need not be pre-reduced: each a[i] and each Bezout coefficient is
// first normalised mod p[i]. This keeps the product k * r * x below
// prod(p) * p[i], so it stays inside __int128.
int1 crt(const vector<int1> &a, const vector<int1> &p)
{
    int1 prd = 1;
    for (const int1 mod : p)
        prd *= mod;

    int1 ans = 0;
    for (size_t i = 0; i < a.size(); i++)
    {
        const int1 k = prd / p[i];
        int1 x, y;
        // k * x = 1 (mod p[i]), valid because the moduli are pairwise coprime.
        exgcd(k, p[i], x, y);
        x %= p[i];
        const int1 r = (a[i] % p[i] + p[i]) % p[i];
        ans = (ans + k * r * x % prd) % prd;
    }
    // ans may be negative because x may be negative; normalise into [0, prd).
    return (ans % prd + prd) % prd;
}
// Extended Lucas theorem: computes C(n, m) mod MOD for any modulus MOD >= 1
// (not necessarily prime).
// MOD is factored into prime powers pc[i] = p[i]^k_i. For each of them,
//   C(n, m) = n!_p * inv(m!_p) * inv((n-m)!_p) * p^(v(n) - v(m) - v(n-m))   (mod pc[i]),
// where x!_p is x! with all factors p removed (wilson) and v(x) is the exponent
// of p in x! (niu). The per-prime-power residues are then combined with crt.
// Returns 0 when m < 0 or m > n, and 0 when MOD == 1.
int1 exlucas(int n, int m, int1 MOD)
{
    if (m < 0 || m > n)
        return 0;

    vector<int1> p, pc;
    int1 tp = MOD;
    for (int i = 2; i * i <= tp; i++)
    {
        if (tp % i != 0)
            continue;
        int res = 1;
        while (tp % i == 0)
        {
            res *= i;
            tp /= i;
        }
        p.push_back(i);
        pc.push_back(res);
    }
    // Whatever remains above 1 is a single prime factor larger than sqrt(MOD).
    // tp == 1 means MOD was fully factored. Pushing 1 as a "prime" would make
    // wilson/niu loop forever, because a /= 1 never shrinks a.
    if (tp > 1)
    {
        p.push_back(tp);
        pc.push_back(tp);
    }

    vector<int1> a;
    for (size_t i = 0; i < p.size(); i++)
    {
        const int1 nw = wilson(n, p[i], pc[i]);
        const int1 mw = wilson(m, p[i], pc[i]);
        const int1 nmw = wilson(n - m, p[i], pc[i]);
        const int1 e = niu(n, p[i]) - niu(m, p[i]) - niu(n - m, p[i]);

        // p[i]^e mod pc[i]. The base must be the prime p[i], not pc[i]. The loop
        // stops as soon as the power reaches 0 (after at most k_i steps), so huge
        // exponents neither take long nor overflow.
        int1 pw = 1;
        for (int1 j = 0; j < e && pw != 0; j++)
            pw = pw * p[i] % pc[i];

        // Reduce after every multiplication so the residue stays in [0, pc[i]).
        a.push_back(nw * inv(mw * nmw % pc[i], pc[i]) % pc[i] * pw % pc[i]);
    }
    return crt(a, pc);
}
// Reads one decimal integer from stdin.
// Leading non-digit characters are skipped; if any of them is '-', the result
// is negative. If stdin is exhausted before a digit is found, returns 0 rather
// than looping forever on EOF.
inline int read()
{
    // int (not char) so that EOF is distinguishable from a valid byte.
    int c = getchar();
    int x = 0, s = 1;
    while (c != EOF && (c < '0' || c > '9'))
    {
        if (c == '-')
            s = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * s;
}
signed main()
{
int1 n, m, p;
n = read();
m = read();
p = read();
write(exlucas(n, m, p));
// cout << '\n';
// write(wilson(5, 3, 3));
// cout << ;
}