이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using pll = pair<ll, ll>;
// Sentinel for "unreachable / not yet computed". Half of LLONG_MAX so that adding two
// sentinels (dp[0][i] + dp[1][i] in solve) does not overflow. Declared constexpr instead of
// an unparenthesized macro so expressions like `2 * inf` evaluate as expected.
constexpr ll inf = LLONG_MAX / 2;
// nodes, edges, pass start (s), pass end (t), person start (u), person end (v);
// the four endpoints are converted to 0-based indices right after reading input.
ll n, m, s, t, u, v;
// adj[x] holds (weight, neighbour) pairs; every edge is stored in both directions.
vector<vector<pll>> adj;
// Shortest-path parent lists (see getPath):
//   path[0] = getPath(t): parents step toward t, so walking from s follows s -> t
//   path[1] = getPath(s): parents step toward s, so walking from t follows t -> s
vector<vector<ll>> path[2];
// dist[0] = shortest distances from u, dist[1] = shortest distances from v
vector<ll> dist[2];
// dp[k][i] = min of dist[k] over node i and every node reachable from i through the
// parent lists used by the most recent dfs(..., whichPath, k) call; inf = not computed.
vector<ll> dp[2];
// Single-source Dijkstra over the global graph `adj`.
// Returns the shortest distance from `src` to every node, or inf for nodes never reached.
vector<ll> getDist(ll src){
    vector<ll> best(n, inf);
    // min-heap of (tentative distance, node)
    priority_queue<pll, vector<pll>, greater<pll>> frontier;
    best[src] = 0;
    frontier.push({0, src});
    while (!frontier.empty()){
        const ll sofar = frontier.top().first;
        const ll here = frontier.top().second;
        frontier.pop();
        // lazy deletion: ignore heap entries superseded by a shorter distance
        if (sofar != best[here]) continue;
        for (const pll& edge : adj[here]){
            const ll to = edge.second;
            const ll cand = sofar + edge.first;
            if (cand >= best[to]) continue;
            best[to] = cand;
            frontier.push({cand, to});
        }
    }
    return best;
}
// Dijkstra from `src` that records, for every node x, ALL neighbours y through which a
// shortest src -> x route arrives (dist[y] + w(y, x) == dist[x]).
// Repeatedly stepping from any node to one of its recorded parents walks a shortest path
// back toward `src`. Assumes positive edge weights; zero-weight edges would make the
// parent lists cyclic.
vector<vector<ll>> getPath(ll src){
    vector<ll> best(n, inf);
    vector<vector<ll>> parents(n);
    priority_queue<pll, vector<pll>, greater<pll>> frontier;
    best[src] = 0;
    frontier.push({0, src});
    while (!frontier.empty()){
        const pll top = frontier.top();
        frontier.pop();
        const ll here = top.second;
        if (top.first != best[here]) continue;   // stale heap entry
        for (const pll& edge : adj[here]){
            const ll to = edge.second;
            const ll cand = top.first + edge.first;
            if (cand > best[to]) continue;
            if (cand < best[to]){
                // strictly shorter route: earlier parents are no longer on a shortest path
                best[to] = cand;
                parents[to].assign(1, here);
                frontier.push({cand, to});
            } else {
                // tie: another equally short way into `to`
                parents[to].push_back(here);
            }
        }
    }
    return parents;
}
// Memoized search over one shortest-path parent DAG.
//   whichPath: 0 -> path[0] (parents step toward t), 1 -> path[1] (parents step toward s)
//   whichDist: 0 -> dist[0] (distances from u),      1 -> dist[1] (distances from v)
// For `node` and every node x reachable from it through path[whichPath], sets
// dp[whichDist][x] = min of dist[whichDist][y] over y = x and all nodes y reachable from x.
// Returns dp[whichDist][node].
// Implemented as an iterative post-order DFS because shortest paths can be O(n) nodes
// long, and a recursive version could exhaust the call stack.
ll dfs(ll node, ll whichPath, ll whichDist){
    vector<ll>& memo = dp[whichDist];
    const vector<ll>& cost = dist[whichDist];
    const vector<vector<ll>>& dag = path[whichPath];
    if (memo[node] != inf) return memo[node];
    // Per-call visit state. memo == inf cannot also mean "finished": a node whose answer
    // really is inf (unreachable from u/v) would then be re-explored on every visit.
    // 0 = not expanded yet, 1 = children scheduled, 2 = finished
    vector<char> state(memo.size(), 0);
    vector<ll> stk(1, node);
    while (!stk.empty()){
        const ll cur = stk.back();
        if (state[cur] == 0){
            state[cur] = 1;
            for (ll nxt : dag[cur])
                if (state[nxt] == 0 && memo[nxt] == inf) stk.push_back(nxt);
            // cur stays on the stack; it is finished once every child pushed above it is done
            continue;
        }
        stk.pop_back();
        if (state[cur] == 2) continue;   // duplicate entry of an already-finished node
        ll lowest = cost[cur];
        for (ll nxt : dag[cur]) lowest = min(lowest, memo[nxt]);
        memo[cur] = lowest;
        state[cur] = 2;
    }
    return memo[node];
}
void solve(){
ll ans = dist[0][v];
fill(dp[0].begin(), dp[0].end(), inf); fill(dp[1].begin(), dp[1].end(), inf);
dfs(s, 0, 0);
dfs(t, 1, 1);
for (int i = 0; i < n; i++) ans = min(ans, dp[0][i]+dp[1][i]);
//swap u and v
fill(dp[0].begin(), dp[0].end(), inf); fill(dp[1].begin(), dp[1].end(), inf);
dfs(s, 0, 1);
dfs(t, 1, 0);
for (int i = 0; i < n; i++) ans = min(ans, dp[0][i]+dp[1][i]);
cout<<ans;
}
int main(){
cin >> n >> m >> s >> t >> u >> v; s--, t--, u--, v--;
adj.resize(n); dp[0].resize(n); dp[1].resize(n);
for (int i = 0; i < m; i++){
ll a, b, w; cin >> a >> b >> w; a--, b--;
adj[a].push_back({w, b});
adj[b].push_back({w, a});
}
path[0] = getPath(t);
path[1] = getPath(s);
dist[0] = getDist(u);
dist[1] = getDist(v);
solve();
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |