// See the comments in the code.
#include <bits/stdc++.h>
using namespace std;
// Capacity for 1-based vertex ids and SCC ids.
// NOTE(review): assumes every vertex id is < N — confirm against the input limits.
const int N = 10005;
int n, m;                 // vertex count, edge count (read by main)
std::vector<int> vec[N];  // adjacency lists: vec[u] holds every v with an edge u -> v

// Tarjan state:
//   dfn[u]  DFS discovery timestamp of u (0 = not visited yet)
//   low[u]  smallest dfn reachable from u's DFS subtree via one edge into a vertex
//           that is still on the stack
//   st, tp  stack of vertices whose SCC is not assigned yet; tp is the 1-based top
//   scc[u]  SCC id of u, numbered 1..sc in order of completion
//   cnt     timestamp counter
//   sc      number of SCCs found so far
int dfn[N], st[N], tp, scc[N], cnt, low[N], sc;
int sz[N];     // sz[c] = number of vertices in SCC c
bool inst[N];  // inst[u] = u is currently on the stack st
int out[N];    // out[c] = edges leaving SCC c; not touched by tarjan (filled in after it)

// Tarjan's strongly-connected-components algorithm, DFS from vertex u.
// Every vertex reachable from u that was not visited before gets an SCC id in
// scc[], and sz[] is updated accordingly. Recursion depth can reach the number of
// vertices.
void tarjan(int u) {
    dfn[u] = low[u] = ++cnt;
    st[++tp] = u;
    inst[u] = true;
    for (int v : vec[u]) {
        if (!dfn[v]) {
            // Tree edge: whatever v's subtree can reach, u can reach too.
            tarjan(v);
            low[u] = std::min(low[u], low[v]);
        } else if (inst[v]) {
            // Edge to a vertex whose SCC is still open (an ancestor or same-SCC vertex).
            low[u] = std::min(low[u], dfn[v]);
        }
        // Otherwise v already belongs to a finished SCC; that edge must be ignored.
    }
    if (dfn[u] == low[u]) {
        // u is the root of an SCC: pop every vertex above it, and u itself.
        ++sc;
        int w;
        do {
            w = st[tp--];
            // The original code omitted this reset. A stale inst[w] lets later
            // cross edges into w lower low[] and merge unrelated vertices into
            // one SCC (e.g. 1->2, 1->3, 3->2 produced {1,3}).
            inst[w] = false;
            scc[w] = sc;
            ++sz[sc];
        } while (w != u);
    }
}
int main() {
cin >> n >> m;
for(int i = 1; i <= m; i ++) {
int a, b;
cin >> a >> b;
vec[a].push_back(b);
}
for(int i = 1; i <= n; i ++) {
if(!dfn[i]) {
tarjan(i);
}
}
for(int i = 1; i <= n; i ++) {
for(int j : vec[i]) {
if(scc[i] == scc[j]) continue;
out[scc[i]] ++;
}
}
bool f = 0;
int ans = 0;
for(int i = 1; i <= sc; i ++) {
if(out[i] == 0) {
if(f) {
ans = 0;
break;
}
f = 1;
ans = sz[i];
}
}
cout << ans << "\n";
return 0;
}