如题。
#include<bits/stdc++.h>
#define int long long
using namespace std;
// Global state shared by dij() and main(). Arrays are indexed directly by node
// id (1..1500). Note that `int` here is `long long` via the #define above.
//   n, m            : node count and edge count
//   u, v, w         : scratch variables (edge input; reused as loop counters in main)
//   s1,t1 / s2,t2   : the two (source, target) pairs read from input
//   mp[a][b]        : weight of the undirected edge a-b (0 is treated as "no edge")
//   mp1[a][b]       : nonzero (the edge weight) if directed edge a->b was found on
//                     a shortest s1->t1 path; mp2 is the same for s2->t2
//   dis             : distances from the source of the most recent dij() call
//   rd, f           : in-degrees and longest-path values for the topological passes
//   ans             : the printed result
int n,m,u,v,w,s1,s2,t1,t2,mp[1501][1501],mp1[1501][1501],dis[1501],mp2[1501][1501],rd[1501],f[1501],ans;
// vis: nodes already settled by dij(); iq: nodes already enqueued while walking
// pre[] backwards from a target.
bool vis[1501],iq[1501];
// pre[x]: predecessors of x on shortest paths from the most recent dij() source.
vector<int>pre[1501];
// e[x]: adjacency list of (neighbour, weight). main() later clears it and
// rebuilds it as the DAG of edges common to both shortest-path DAGs.
vector<pair<int,int>>e[1501];
// qv: FIFO used for the backward walk over pre[] and for the topological order.
queue<int>qv;
// q: Dijkstra heap. Entries are (-distance, node), so this max-heap pops the
// closest node first.
priority_queue<pair<int,int>>q;
// Single-source Dijkstra from s over the adjacency lists e[1..n].
// On return, dis[x] holds the shortest distance from s to x (left 0x3f-filled
// if x is unreachable). pre[x] lists the nodes y whose edge y-x achieved the
// final dis[x] when y was settled. Assuming positive edge weights, these are
// exactly the predecessors of x on shortest paths from s.
// The global heap q serves as the priority queue and is empty again on return.
inline void dij(int s){
    for(int node = 1; node <= n; ++node) pre[node].clear();
    memset(vis,0,sizeof(vis));
    memset(dis,0x3f,sizeof(dis));
    dis[s] = 0;
    q.emplace(0,s);   // keys are negated distances, so the max-heap yields the closest node
    while(!q.empty()){
        const int cur = q.top().second;
        q.pop();
        if(vis[cur]) continue;   // stale heap entry for an already-settled node
        vis[cur] = true;
        for(const auto &edge : e[cur]){
            const int to = edge.first;
            const int cand = dis[cur] + edge.second;
            if(cand < dis[to]){
                // Strictly shorter route: cur becomes the only known predecessor.
                dis[to] = cand;
                pre[to].assign(1,cur);
                q.emplace(-cand,to);
            }else if(cand == dis[to]){
                // An equally short route: cur is an additional predecessor.
                pre[to].push_back(cur);
            }
        }
    }
}
// Reads an undirected weighted graph and two (source, target) pairs. Prints the
// maximum total weight of a path that lies on some shortest s1->t1 path and
// also on some shortest s2->t2 path. The shared edges may be traversed in the
// same direction by both routes or in opposite directions. Both cases are
// evaluated separately, and the larger result is printed.
signed main(){
    scanf("%lld%lld%lld%lld%lld%lld",&n,&m,&s1,&t1,&s2,&t2);
    for(int i = 1;i <= m;i ++){
        scanf("%lld%lld%lld",&u,&v,&w);
        // Among parallel edges only the lightest can lie on a shortest path,
        // so keep that weight for the DAG-edge lookups below.
        if(!mp[u][v] || w < mp[u][v]) mp[u][v] = mp[v][u] = w;
        e[u].push_back({v,w}),e[v].push_back({u,w});
    }

    // Walks the pre[] lists filled by the preceding dij() call backwards from t.
    // Every directed edge a->b lying on at least one shortest path that ends at t
    // gets mark[a][b] set to its weight.
    auto markShortestPathEdges = [&](int t, auto &mark){
        memset(iq,0,sizeof(iq));
        qv.push(t);
        while(qv.size()){
            int b = qv.front();
            qv.pop();
            for(int a : pre[b]){
                if(!iq[a]){
                    iq[a] = true;
                    qv.push(a);
                }
                mark[a][b] = mp[a][b];
            }
        }
    };

    // Builds the graph of edges present in both shortest-path DAGs. When
    // sameDirection is true, both routes must traverse an edge the same way;
    // otherwise they must traverse it in opposite directions. The result is a
    // subgraph of the s1 DAG, hence acyclic. Returns the weight of its heaviest
    // path, computed by relaxing edges in Kahn topological order. Reuses the
    // globals e, rd, f and qv.
    auto longestCommonPath = [&](bool sameDirection){
        for(int i = 1;i <= n;i ++){
            e[i].clear();
            rd[i] = f[i] = 0;
        }
        for(int a = 1;a <= n;a ++){
            for(int b = 1;b <= n;b ++){
                int onSecond = sameDirection ? mp2[a][b] : mp2[b][a];
                if(mp1[a][b] && onSecond){
                    e[a].push_back({b,mp[a][b]});
                    rd[b] ++;
                }
            }
        }
        for(int i = 1;i <= n;i ++) if(!rd[i]) qv.push(i);
        while(qv.size()){
            int x = qv.front();
            qv.pop();
            for(auto [y,len] : e[x]){
                f[y] = max(f[y],f[x] + len);
                if(!--rd[y]) qv.push(y);
            }
        }
        int best = 0;
        // Index with the loop variable i: the original read f[v] with the stale
        // global v (== n+1), which lost the same-direction answer.
        for(int i = 1;i <= n;i ++) best = max(best,f[i]);
        return best;
    };

    dij(s1);
    markShortestPathEdges(t1,mp1);
    dij(s2);
    markShortestPathEdges(t2,mp2);

    ans = max(longestCommonPath(true),longestCommonPath(false));
    printf("%lld",ans);
    return 0;
}