Submission #1267387

# | Time | Username | Problem | Language | Result | Execution time | Memory
1267387 | (not shown) | gay | Olympic Bus (JOI20_ho_t4) | C++20 | 0 / 100 | 1104 ms | 321028 KiB
#include <bits/stdc++.h>
#include <experimental/random>
#include <random>

using namespace std;
using ll = long long;
using ld = long double;

// INF is the "unreachable" distance sentinel; results >= INF are reported as -1.
// MOD is not referenced anywhere in the code shown.
constexpr ll INF = 2'000'000'000'000'000'000LL;
constexpr ll MOD = 998'244'353LL;

void solve();

signed main() {
#ifdef LOCAL
    freopen("input.txt", "r", stdin);
    freopen("output.txt", "w", stdout);
#endif
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    ll q = 1;
    //cin >> q;
    while (q--) {
        solve();
    }
}

// Adjacency-list arc used by g / rev_g in solve(): push_back({v, c, d, i}).
// Code elsewhere destructures it positionally as [v, c, d, id], so member
// order matters. solve() also reuses this type to hold raw input tuples
// (u, v, c, d) in the same four slots.
struct e {
    ll v;   // vertex the arc leads to
    ll c;   // cost of travelling along the arc
    ll d;   // cost added to the answer when this road is reversed
    ll id;  // index of the road in input order
};

// Vertex and road counts, read at the start of solve().
ll n = 0;
ll m = 0;

// Dijkstra-style sweep from `st` that records one value per (vertex, edge id).
// best[x][k] is the smallest `w + c` produced when arc k (cost c) was relaxed
// into x from a vertex popped at tentative distance w. Entries never relaxed
// stay INF. The priority set is keyed by (distance, vertex) only.
// Note: this allocates an n x (m + 1) table of long longs, so memory grows
// with n * m.
vector<vector<ll>> calc(ll st, vector<vector<e>>& g) {
    vector<vector<ll>> best(n, vector<ll>(m + 1, INF));
    best[st][m] = 0;  // seed the start in column m, which no edge id uses

    set<pair<ll, ll>> pending;
    pending.emplace(0, st);
    while (!pending.empty()) {
        const auto [dist, from] = *pending.begin();
        pending.erase(pending.begin());
        for (const e& arc : g[from]) {
            ll& slot = best[arc.v][arc.id];
            const ll reach = dist + arc.c;
            if (reach >= slot) {
                continue;
            }
            pending.erase({slot, arc.v});
            slot = reach;
            pending.emplace(slot, arc.v);
        }
    }
    return best;
}

vector<ll> get_path(vector<vector<e>>& g, ll s, ll t) {
    vector<pair<ll, ll>> pred(n, {-1, -1});
    vector<ll> dp(n, INF);
    dp[s] = 0;
    set<pair<ll, ll>> dj;
    dj.insert({0, s});
    while (!empty(dj)) {
        auto [w, u] = *dj.begin();
        dj.erase(dj.begin());
        for (auto [v, c, d, id] : g[u]) {
            if (dp[v] > w + c) {
                pred[v] = {u, id};
                dj.erase({dp[v], v});
                dp[v] = w + c;
                dj.insert({dp[v], v});
            }
        }
    }

    vector<ll> used(m);
    if (dp[t] == INF) {
        return used;
    }

    while (pred[t].second != -1) {
        used[pred[t].second] = 1;
        t = pred[t].first;
    }

    return used;
}

// Dijkstra from s over g. The result is correct only for non-negative arc costs.
// Returns the length of a shortest s -> t path, or INF if t is unreachable.
ll cost(vector<vector<e>>& g, ll s, ll t) {
    vector<ll> dist(n, INF);
    set<pair<ll, ll>> pending;

    dist[s] = 0;
    pending.emplace(0, s);
    while (!pending.empty()) {
        const pair<ll, ll> top = *pending.begin();
        pending.erase(pending.begin());
        for (const e& arc : g[top.second]) {
            const ll through = top.first + arc.c;
            if (through < dist[arc.v]) {
                // Keep at most one queue entry per vertex: drop the stale one first.
                pending.erase(make_pair(dist[arc.v], arc.v));
                dist[arc.v] = through;
                pending.emplace(through, arc.v);
            }
        }
    }
    return dist[t];
}

// Reverses the road with id `idx` that currently leaves `u` (expected to lead to `v`).
// The arc is removed from g[u], and the arc v -> u is appended to g[v] with the
// same cost, reversal cost and id. Calling rev_ed(v, u, idx, g) undoes this,
// except that the restored arc ends up at the back of g[u].
// Assumes ids are unique within one adjacency list; solve() assigns one id per road.
// If no arc with that id leaves `u`, the graph is left unchanged. It no longer
// gains a zero-cost dummy arc with id 0.
void rev_ed(ll u, ll v, ll idx, vector<vector<e>>& g) {
    vector<e>& out = g[u];
    const auto it = find_if(out.begin(), out.end(),
                            [idx](const e& arc) { return arc.id == idx; });
    if (it == out.end()) {
        return;
    }
    // Copy before erasing; g[v] may be the same vector when u == v.
    const e moved = *it;
    out.erase(it);  // erase keeps the relative order of the remaining arcs
    g[v].push_back({u, moved.c, moved.d, moved.id});
}

void solve() {
    cin >> n >> m;
    vector<vector<e>> g(n), rev_g(n);
    vector<e> edges;
    for (int i = 0; i < m; i++) {
        ll u, v, c, d;
        cin >> u >> v >> c >> d;
        u--, v--;
        edges.push_back({u, v, c, d});
        g[u].push_back({v, c, d, i});
        rev_g[v].push_back({u, c, d, i});
    }

    vector<vector<ll>> dp1 = calc(0, g);
    vector<vector<ll>> dp2 = calc(n - 1, rev_g);
    vector<vector<ll>> dp3 = calc(n - 1, g);
    vector<vector<ll>> dp4 = calc(0, rev_g);

    vector<ll> pt1 = get_path(g, 0, n - 1);
    vector<ll> pt2 = get_path(g, n - 1, 0);
    ll c1 = cost(g, 0, n - 1), c2 = cost(g, n - 1, 0);

    ll ans = c1 + c2;

    for (int i = 0; i < m; i++) {
        auto [u, v, w, d] = edges[i];
        rev_ed(u, v, i, g);

        ll p1 = c1, p2 = c2;
        if (pt1[i]) {
            p1 = cost(g, 0, n - 1);
        }
        if (pt2[i]) {
            p2 = cost(g, n - 1, 0);
        }

        ll cnt = INF;
        for (int pr1 = 0; pr1 <= m; pr1++) {
            if (pr1 != i) {
                cnt = min(cnt, dp1[v][pr1] + w);
            }
        }
        ll cnt2 = INF;
        for (int pr2 = 0; pr2 <= m; pr2++) {
            if (pr2 != i) {
                cnt2 = min(cnt2, dp2[u][pr2]);
            }
        }
        p1 = min(p1, cnt + cnt2);

        cnt = INF;
        for (int pr1 = 0; pr1 <= m; pr1++) {
            if (pr1 != i) {
                cnt = min(cnt, dp3[v][pr1] + w);
            }
        }
        cnt2 = INF;
        for (int pr2 = 0; pr2 <= m; pr2++) {
            if (pr2 != i) {
                cnt2 = min(cnt2, dp4[u][pr2]);
            }
        }
        p2 = min(p2, cnt + cnt2);

        ans = min(ans, p1 + p2 + d);

        rev_ed(v, u, i, g);
    }

    if (ans >= INF) {
        cout << -1;
        return;
    }

    cout << ans;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...