Submission #1209772

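| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1209772 | | dima2101 | Olympic Bus (JOI20_ho_t4) | C++20 | 100 / 100 | 38 ms | 6472 KiB |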
#include <bits/stdc++.h>
#define int long long

/// One directed edge `start -> stop` of the input graph.
///
/// - `cost_go`  : cost added when the edge is travelled.
/// - `cost_rev` : cost added when this edge is the one chosen to be reversed.
/// - `ind`      : position of the edge in the input (0-based).
///
/// A default-constructed Node has `start == -1`; the search code uses that
/// as the "no edge recorded" marker. Every member now has a defined default
/// so copies of default Nodes never read indeterminate values.
struct Node
{
    int stop = 0;
    int cost_go = 0;
    int cost_rev = 0;
    int start = -1;
    int ind = -1;

    // Initializers are listed in declaration order, which is the order the
    // compiler actually runs them in.
    Node(int start, int stop, int cost_go, int cost_rev, int ind)
        : stop(stop), cost_go(cost_go), cost_rev(cost_rev), start(start), ind(ind) {}

    Node() = default;
};

/// Floyd–Warshall relaxation over an n x n cost matrix.
///
/// Works on its own copy of `g` and returns it after trying every vertex as
/// an intermediate; when the matrix has no negative cycles, entry [i][j] is
/// then the cheapest i -> j cost.
std::vector<std::vector<int>> floyd(std::vector<std::vector<int>> g, int n)
{
    for (int mid = 0; mid < n; ++mid)
    {
        // References, not copies: reads below always see the current values.
        const std::vector<int> &through = g[mid];
        for (int from = 0; from < n; ++from)
        {
            std::vector<int> &row = g[from];
            for (int to = 0; to < n; ++to)
            {
                const int candidate = row[mid] + through[to];
                if (candidate < row[to])
                    row[to] = candidate;
            }
        }
    }
    return g;
}

std::vector<Node> dij(std::vector<std::vector<Node>> &g, int start, int stop, int n)
{
    std::vector<int> dist(n, 1e18);
    std::vector<Node> prev(n);
    dist[start] = 0;
    std::set<std::pair<int, int>> s;
    s.insert({dist[start], start});
    while (s.size() > 0)
    {
        int v = s.begin()->second;
        s.erase(s.begin());

        for (auto u : g[v])
        {
            if (dist[u.stop] > dist[v] + u.cost_go)
            {
                s.erase({dist[u.stop], u.stop});
                dist[u.stop] = dist[v] + u.cost_go;
                prev[u.stop] = u;
                s.insert({dist[u.stop], u.stop});
            }
        }
    }
    int now = stop;
    std::vector<Node> ans;
    while (prev[now].start >= 0)
    {
        ans.push_back(prev[now]);
        now = prev[now].start;
    }
    std::reverse(ans.begin(), ans.end());
    return ans;
}

signed
main()
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(0);

    int n, m;
    std::cin >> n >> m;

    std::vector<Node> all(m);
    std::vector<std::vector<Node>> g(n);
    for (int i = 0; i < m; i++)
    {
        int a, b, c, p;
        std::cin >> a >> b >> c >> p;
        a--, b--;
        all[i] = Node(a, b, c, p, i);
        g[a].push_back(Node(a, b, c, p, i));
    }

    std::vector<Node> from_1_n = dij(g, 0, n - 1, n);
    std::set<int> help1;
    std::set<int> helpn;

    std::vector<Node> from_n_1 = dij(g, n - 1, 0, n);
    std::vector<int> vert_1_n;
    std::vector<int> vert_n_1;
    for (auto i : from_1_n)
    {
        vert_1_n.push_back(i.start);
        help1.insert(i.ind);
    }
    vert_1_n.push_back(n - 1);
    for (auto i : from_n_1)
    {
        vert_n_1.push_back(i.start);
        helpn.insert(i.ind);
    }
    vert_n_1.push_back(0);
    std::vector<std::vector<int>> best(n, std::vector<int>(n, 1e14));
    for (int v = 0; v < n; v++)
    {
        for (auto u : g[v])
        {
            best[v][u.stop] = std::min(best[v][u.stop], u.cost_go);
        }
    }
    for (int i = 0; i < n; i++)
    {
        best[i][i] = 0;
    }

    best = floyd(best, n);

    std::vector<std::vector<int>> best_1_n(n, std::vector<int>(n, 1e14));
    std::vector<std::vector<int>> best_n_1(n, std::vector<int>(n, 1e14));

    for (int v = 0; v < n; v++)
    {
        for (auto u : g[v])
        {
            if (help1.find(u.ind) == help1.end())
            {
                best_1_n[v][u.stop] = std::min(best_1_n[v][u.stop], u.cost_go);
            }
            if (helpn.find(u.ind) == helpn.end())
            {
                best_n_1[v][u.stop] = std::min(best_n_1[v][u.stop], u.cost_go);
            }
        }
    }

    for (int i = 0; i < n; i++)
    {
        best_1_n[i][i] = 0;
        best_n_1[i][i] = 0;
    }

    best_1_n = floyd(best_1_n, n);
    best_n_1 = floyd(best_n_1, n);

    int min = best[0][n - 1] + best[n - 1][0];
    for (int i = 0; i < m; i++)
    {
        int min_1_n = 1e13;
        if (help1.find(i) != help1.end())
        {
            int help_ind = -1;
            int cnt = 0;
            for (auto j : from_1_n)
            {
                if (i == j.ind)
                {
                    help_ind = cnt;
                }
                cnt++;
            }
            assert(help_ind != -1);
            for (int l = 0; l <= help_ind; l++)
            {
                for (int r = help_ind + 1; r < vert_1_n.size(); r++)
                {
                    // std::cout << l << ' ' << r << ' ' << best_1_n[vert_1_n[l]][vert_1_n[r]] << std::endl;
                    min_1_n = std::min(min_1_n, best[0][vert_1_n[l]] + best[vert_1_n[r]][n - 1] +
                                                    best_1_n[vert_1_n[l]][vert_1_n[r]]);
                }
            }
        }
        else
        {
            min_1_n = best[0][all[i].stop] + best[all[i].start][n - 1] + all[i].cost_go;
            min_1_n = std::min(min_1_n, best[0][n - 1]);
        }

        int min_n_1 = 1e13;

        // std::cout << all[i].start << ' ' << all[i].stop << std::endl;
        // std::cout << best[n - 1][all[i].stop] << ' ' << best[all[i].start][0] << std::endl;
        if (helpn.find(i) != helpn.end())
        {
            int help_ind = -1;
            int cnt = 0;
            for (auto j : from_n_1)
            {
                if (i == j.ind)
                {
                    help_ind = cnt;
                }
                cnt++;
            }

            assert(help_ind != -1);

            for (int l = 0; l <= help_ind; l++)
            {
                for (int r = help_ind + 1; r < vert_n_1.size(); r++)
                {
                    min_n_1 = std::min(min_n_1, best[n - 1][vert_n_1[l]] + best[vert_n_1[r]][0] +
                                                    best_n_1[vert_n_1[l]][vert_n_1[r]]);
                }
            }
        }
        else
        {
            min_n_1 = best[n - 1][all[i].stop] + best[all[i].start][0] + all[i].cost_go;

            min_n_1 = std::min(min_n_1, best[n - 1][0]);
        }
        // std::cout << min_1_n << ' ' << ' ' << min_n_1 << ' ' << all[i].cost_rev << std::endl;
        //   std::cout << all[i].start << ' ' << all[i].stop << ' ' << min_1_n << ' ' << min_n_1 << std::endl;
        min = std::min(min, min_1_n + min_n_1 + all[i].cost_rev);
    }
    if (min > 1e12)
    {
        std::cout << -1 << std::endl;
        return 0;
    }
    std::cout << min << std::endl;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...