// Submission #1185117
//
// #       | Time | Username | Problem               | Language | Result   | Execution time | Memory
// 1185117 |      | qwusha   | Cities (BOI16_cities) | C++20    | 74 / 100 | 1117 ms        | 40568 KiB
#include <bits/stdc++.h>
using namespace std;
/*
#pragma GCC optimize("O3")
#include <bitset>
#pragma GCC target("avx2")*/
// Shorthands for std::pair members.
#define fi first
#define se second
// From here on every `int` in this file is really `long long` (which is why the
// entry point is declared `signed main`); the 10^15 sentinel below relies on it.
#define int long long
typedef long long ll;
typedef long double ld;
// Time-seeded random generator; not referenced anywhere else in this file.
mt19937 rnd(chrono::high_resolution_clock::now().time_since_epoch().count());
// Distance sentinel used for "unreachable" / "not computed yet" (10^15).
int inf = 1e15;

// n = number of vertices, k = number of important cities, m = number of edges;
// all three are read in main.
int n, k, m;
// Undirected adjacency list, 0-based: g[v] holds (neighbour, edge weight) pairs,
// each edge being stored in both endpoints' lists by main.
vector<vector<pair<int,int>>> g;


// Union-find (disjoint set union) over the integers 0 .. siz-1.
// A root is an element x with par[x] == x. get_par compresses paths and merg
// links by set size, so sz[x] is only meaningful while x is a root.
struct dsu {
    vector<int> par;  // parent link of each element (a root points to itself)
    vector<int> sz;   // number of elements in the set rooted at this element

    // (Re)initialise to `siz` singleton sets {0}, {1}, ..., {siz-1}.
    void init(int siz) {
        par.assign(siz, 0);
        iota(par.begin(), par.end(), 0);
        sz.assign(siz, 1);
    }

    // Return the root of v's set, re-pointing every element on the walked
    // path directly at that root.
    int get_par(int v) {
        int root = v;
        while (par[root] != root) {
            root = par[root];
        }
        while (par[v] != root) {
            const int up = par[v];
            par[v] = root;
            v = up;
        }
        return root;
    }

    // Join the sets containing v and u. The root of the larger set survives
    // (on a tie, v's root). Does nothing if both are already in one set.
    void merg(int v, int u) {
        int keep = get_par(v);
        int drop = get_par(u);
        if (keep == drop) {
            return;
        }
        if (sz[drop] > sz[keep]) {
            swap(keep, drop);
        }
        par[drop] = keep;
        sz[keep] += sz[drop];
    }
};


// Single-source shortest paths (Dijkstra) from vertex `st` over the global
// adjacency list `g` of `n` vertices. Edge weights are assumed non-negative.
//
// Returns {dist, par}:
//   dist[v] - length of a shortest st->v path, or `inf` if v is unreachable;
//   par[v]  - predecessor of v on that path (-1 for st and for unreachable v),
//             so following par from a reachable v walks back to st.
pair<vector<int>, vector<int>> dijktra(int st) {
    vector<int> dist(n, inf);
    vector<int> par(n, -1);
    dist[st] = 0;
    // The ordered set acts as an addressable priority queue: a vertex's old
    // (distance, vertex) key is erased before the improved one is inserted,
    // so every entry popped is final.
    set<pair<int, int>> q = {{0, st}};
    while (!q.empty()) {
        const int v = q.begin()->second;
        q.erase(q.begin());
        for (const auto& [u, w] : g[v]) {
            const int cand = dist[v] + w;
            if (cand < dist[u]) {
                q.erase({dist[u], u});
                dist[u] = cand;
                par[u] = v;
                q.insert({dist[u], u});
            }
        }
    }
    return {dist, par};
}

int K = 200;


signed main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cin >> n >> k >> m;
    vector<int> a(k);
    g.resize(n);
    for (int i = 0; i < k; i++) {
        cin >> a[i];
        a[i]--;
    }
    for (int i = 0; i < m; i++) {
        int v, u, w;
        cin >> v >> u >> w;
        g[v - 1].push_back({u - 1, w});
        g[u - 1].push_back({v - 1, w});
    }
    if (n <= 20) {
        int mini = inf;
        for (int ma = 0; ma < (1 << n); ma++) {
            bool ok = 1;
            for (auto el: a) {
                if (((ma >> el) & 1) == 0) {
                    ok = 0;
                }
            }
            if (!ok)
                continue;
            vector<int> ind;
            map<int, int> mp;
            for (int i = 0; i < n; i++) {
                if ((ma >> i) & 1) {
                    mp[i] = ind.size();
                    ind.push_back(i);
                }
            }
            int siz = ind.size();
            dsu dsu;
            dsu.init(siz);
            vector<vector<int>> edg;
            for (int i = 0; i < n; i++) {
                if (((ma >> i) & 1) == 1) {
                    for (auto [u, w]: g[i]) {
                        if ((ma >> u) & 1) {
                            edg.push_back({w, mp[i], mp[u]});
                        }
                    }
                }
            }
            sort(edg.begin(), edg.end());
            int sum = 0;
            int comp = siz;
            for (int i = 0; i < edg.size(); i++) {
                auto s = edg[i];
                if (dsu.get_par(s[1]) != dsu.get_par(s[2])) {
                    comp--;
                    sum += s[0];
                    dsu.merg(s[1], s[2]);
                }
            }
            if (comp == 1) {
                mini = min(mini, sum);
            }
        }
        cout << mini << '\n';
        return 0;
    }
    vector<vector<int>> di, pars;
    for (auto el: a) {
        auto spa = dijktra(el);
        di.push_back(spa.fi);
        pars.push_back(spa.se);
    }
    int mini = inf;
    if (k <= 3) {
        for (int i = 0; i < n; i++) {
            int val = 0;
            for (int j = 0; j < k; j++) {
                val += di[j][i];
            }
            mini = min(mini, val);
        }
        cout << mini << '\n';
    }else if (k == 4) {
        vector<vector<int>> stuff;
        for (int i = 0; i < n; i ++) {
            for (int blo = 0; blo < k; blo++) {
                int val = 0;
                for (int j = 0; j < k; j++) {
                    if (j == blo)
                        continue;
                    val += di[j][i];
                }
                stuff.push_back({val, i, blo});
            }
        }
        sort(stuff.begin(), stuff.end());
        for (int i = 0; i < 300; i++) {
            auto s = stuff[i];
            int val = s[0];
            int cen = s[1];
            int blo = s[2];
            vector<int> cov;
            for (int j = 0; j < k; j++) {
                if (j == blo)
                    continue;
                int cur = cen;
                while (cur != -1) {
                    cov.push_back(cur);
                    cur = pars[j][cur];
                }
            }
            int mbl = inf;
            for (auto pos : cov) {
                mbl = min(mbl, di[blo][pos]);
            }
            val += mbl;
            mini = min(mini, val);
        }
        cout << mini << '\n';
        return 0;
    } else {
        vector<vector<int>> stuff;
        vector<vector<int>> nst;
        for (int i = 0; i < n; i ++) {
            for (int bl1 = 0; bl1 < k; bl1++) {
                for (int bl2 = 0; bl2 < k; bl2++) {
                    int val = 0;
                    for (int j = 0; j < k; j++) {
                        if (j == bl1 || j == bl2)
                            continue;
                        val += di[j][i];
                    }
                    stuff.push_back({val, i, bl1, bl2});
                }
            }
            sort(stuff.begin(), stuff.end());
            nst.clear();
            for (int j = 0; j < min(K, (int)stuff.size()); j++) {
                nst.push_back(stuff[j]);
            }
            stuff = nst;
        }
        vector<vector<int>> stu;
        for (int i = 0; i < K; i++) {
            auto s= stuff[i];
            int val = s[0];
            int cen = s[1];
            int bl1 = s[2];
            int bl2 = s[3];
            vector<int> cov;
            for (int j = 0; j < k; j++) {
                if (j == bl1 || j == bl2)
                    continue;
                int cur = cen;
                while (cur != -1) {
                    cov.push_back(cur);
                    cur = pars[j][cur];
                }
            }
            for (auto pos : cov) {
                int nval = val + di[bl1][pos];
                stu.push_back({nval, cen, bl1, bl2, pos});
            }

            sort(stu.begin(), stu.end());
            nst.clear();
            for (int j = 0; j < min(K, (int)stu.size()); j++) {
                nst.push_back(stu[j]);
            }
            stu = nst;
        }
        for (int i = 0; i < K; i++) {
            auto s=stu[i];
            int val = s[0];
            int cen = s[1];
            int bl1 = s[2];
            int bl2 = s[3];
            int bp = s[4];
            vector<int> cov;
            for (int j = 0; j < k; j++) {
                if (j == bl1 || j == bl2)
                    continue;
                int cur = cen;
                while (cur != -1) {
                    cov.push_back(cur);
                    cur = pars[j][cur];
                }
            }
            int cur = bp;
            while (cur != -1) {
                cov.push_back(cur);
                cur = pars[bl1][cur];
            }
            int mbl = inf;
            for (auto pos : cov) {
                if (di[bl2][pos] < mbl) {
                    mbl = di[bl2][pos];
                }
            }
            val += mbl;
            mini = min(mini, val);
        }
        cout << mini << '\n';
    }
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...