답안 #646250

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
646250 2022-09-29T10:10:22 Z fatemetmhr Cities (BOI16_cities) C++17
51 / 100
6000 ms 25532 KB
#include <bits/stdc++.h>

using namespace std;

// 64-bit integer used for all path/tree costs.
typedef long long ll;

// Shorthand macros used throughout the file.
// (`mp` is not used anywhere in the visible code.)
#define all(x)  x.begin(), x.end()
#define pb      push_back
#define fi      first
#define se      second
#define mp      make_pair

// Array capacities and the "unreachable" distance sentinel.
const int maxn5 = 2e5 + 10;
const int maxn3 = 1e3 + 10;
const ll  inf   = 1e18;

// Number of vertices, read in main().
int n;
// Important vertices, stored 0-based in input order.
vector <int> imp;
// adj[v] holds (neighbour, weight) pairs; main() stores each input edge in both directions.
vector <pair<int, ll>> adj[maxn5];
// Ordered set of (distance, vertex) that dij()/dij2() use as their priority queue.
set <pair<ll, int>> av;
// dis[i][v]  : filled by dij(r, i), distance from r to v (main() passes r = imp[i]).
// dis2[s][v] : filled by dij2(r, s); main() calls it with s == r for every vertex.
// NOTE(review): dis2 only has maxn3 rows/columns, but main() fills it for all n
// vertices -- looks out of bounds once n exceeds maxn3; confirm.
ll dis[5][maxn5], dis2[maxn3][maxn3];
// mark[v] selects the vertices that mst() is allowed to use.
// NOTE(review): only 23 entries, while main() writes mark[a] for every input
// vertex and its brute-force branch admits n up to 30 -- size looks too small; confirm.
bool mark[23];
// Disjoint-set parent links for get_par()/mst(); -1 marks a root.
// NOTE(review): mst() resets n + 5 entries, which exceeds 30 slots when n > 25; confirm.
int par[30];
// Scratch edge list for mst(): (weight, (u, v)) with u < v.
vector <pair<ll, pair<int, int>>> ed;

// Dijkstra from vertex r over the global adjacency list `adj`; results go
// into row `id` of the global table `dis`.
//
//   r  - source vertex (0-based).
//   id - row of `dis` to overwrite; entries 0..n-1 of that row are rewritten.
//
// Vertices that are not reachable from r keep the value `inf`.
// The shared set `av` serves as the priority queue and is empty on return.
inline void dij(int r, int id){
    av.clear();
    for(int i = 0; i < n; i++)
        dis[id][i] = inf;
    dis[id][r] = 0;
    // Only the source is queued initially; any other vertex enters the set
    // the first time its tentative distance drops below inf.
    av.insert({dis[id][r], r});
    while(av.size()){
        int v = av.begin() -> se;
        av.erase(av.begin());
        for(const auto &p : adj[v]){
            int u = p.fi;
            ll w = p.se;
            // Skip unless the new path is strictly shorter: re-queuing on ties
            // does redundant set work and, with zero-weight edges, can put
            // already settled vertices back into the queue indefinitely.
            if(dis[id][u] <= dis[id][v] + w)
                continue;
            av.erase({dis[id][u], u});  // no-op when u is not queued yet
            dis[id][u] = dis[id][v] + w;
            av.insert({dis[id][u], u});
        }
    }
    return;
}

// Dijkstra from vertex r over the global adjacency list `adj`; results go
// into row `id` of the global table `dis2`.
//
//   r  - source vertex (0-based).
//   id - row of `dis2` to overwrite; entries 0..n-1 of that row are rewritten.
//
// `dis2` is declared as maxn3 x maxn3, so this is only valid while both `id`
// and `n` stay within maxn3.
// Vertices that are not reachable from r keep the value `inf`.
// The shared set `av` serves as the priority queue and is empty on return.
inline void dij2(int r, int id){
    av.clear();
    for(int i = 0; i < n; i++)
        dis2[id][i] = inf;
    dis2[id][r] = 0;
    // Only the source is queued initially; any other vertex enters the set
    // the first time its tentative distance drops below inf.
    av.insert({dis2[id][r], r});
    while(av.size()){
        int v = av.begin() -> se;
        av.erase(av.begin());
        for(const auto &p : adj[v]){
            int u = p.fi;
            ll w = p.se;
            // Skip unless the new path is strictly shorter: re-queuing on ties
            // does redundant set work and, with zero-weight edges, can put
            // already settled vertices back into the queue indefinitely.
            if(dis2[id][u] <= dis2[id][v] + w)
                continue;
            av.erase({dis2[id][u], u});  // no-op when u is not queued yet
            dis2[id][u] = dis2[id][v] + w;
            av.insert({dis2[id][u], u});
        }
    }
    return;
}

// Disjoint-set "find": returns the representative of the set containing a.
// A vertex x is a representative when par[x] == -1. Every vertex walked over
// on the way up is re-pointed straight at the representative.
int get_par(int a){
    // First pass: climb to the root.
    int root = a;
    while(par[root] != -1)
        root = par[root];
    // Second pass: compress the path that was just climbed.
    while(a != root){
        int up = par[a];
        par[a] = root;
        a = up;
    }
    return root;
}

// Minimum spanning tree weight of the subgraph induced by the vertices v
// with mark[v] set (Kruskal's algorithm on the global `adj`).
//
// Returns the total weight of the spanning tree, or `inf` when the marked
// vertices do not form a single connected component in that subgraph.
// With no marked vertex at all the result is 0.
//
// Uses the globals `par` (disjoint-set parents) and `ed` (edge scratch list);
// both are reset here, so no state carries over between calls.
inline ll mst(){
    // Only vertices 0..n-1 are ever passed to get_par(), so exactly n parent
    // slots are reset; resetting more would run past the end of `par` when n
    // is close to its capacity.
    fill(par, par + n, -1);
    ed.clear();
    int cnt = 0;  // number of components among marked vertices
    for(int i = 0; i < n; i++) if(mark[i]){
        cnt++;
        // i < p.fi keeps one copy of each undirected edge.
        for(const auto &p : adj[i]) if(i < p.fi && mark[p.fi])
            ed.pb({p.se, {i, p.fi}});
    }
    sort(all(ed));  // ascending by weight
    ll sum = 0;
    for(const auto &p : ed){
        ll w = p.fi;
        int u = p.se.fi, v = p.se.se;
        u = get_par(u); v = get_par(v);
        if(u == v)
            continue;  // edge would close a cycle
        cnt--;
        sum += w;
        par[u] = v;
    }
    if(cnt > 1)
        return inf;
    return sum;
}

int main(){
    ios_base::sync_with_stdio(false); cin.tie(0); cout.tie(0);

    int k, m; cin >> n >> k >> m;
    for(int i = 0; i < k; i++){
        int a; cin >> a; a--;
        imp.pb(a);
        mark[a] = true;
    }
    for(int i = 0; i < m; i++){
        int a, b, c; cin >> a >> b >> c;
        a--; b--;
        adj[a].pb({b, c});
        adj[b].pb({a, c});
    }

    if(k == 5){
        if(n > 30)
            return 0;
        vector <int> ver;
        ver.clear();
        ll ans = inf;
        for(int i = 0; i < n; i++) if(!mark[i])
            ver.pb(i);
        for(int mask = 0; mask < (1 << (n - k)); mask++){
            memset(mark, false, sizeof mark);
            for(int i = 0; i < n - k; i++) if((mask >> i)&1)
                mark[ver[i]] = true;
            for(auto u : imp)
                mark[u] = true;
            ans = min(ans, mst());
        }
        return cout << ans << endl, 0;
    }

    for(int i = 0; i < k; i++)
        dij(imp[i], i);

    if(k == 2){
        cout << dis[0][imp[1]] << endl;
        return 0;
    }

    if(k == 3){
        ll ans = dis[0][imp[1]] + dis[1][imp[2]];
        for(int i = 0; i < n; i++)
            ans = min(ans, dis[0][i] + dis[1][i] + dis[2][i]);
        return cout << ans << endl, 0;
    }

    for(int i = 0; i < n; i++)
        dij2(i, i);

    ll ans = inf;
    //cout << "check " << endl;
    int a0 = imp[0], a1 = imp[1], a2 = imp[2], a3 = imp[3];
    for(int i = 0; i < n; i++) for(int j = 0; j < n; j++){
        ans = min(ans, dis2[a0][i] + dis2[a1][i] + dis2[a2][j] + dis2[a3][j] + dis2[i][j]);
        ans = min(ans, dis2[a0][i] + dis2[a1][j] + dis2[a2][i] + dis2[a3][j] + dis2[i][j]);
        ans = min(ans, dis2[a0][i] + dis2[a1][j] + dis2[a2][j] + dis2[a3][i] + dis2[i][j]);
        //cout << i << ' ' << j << ' ' << ans << endl;
    }
    //for(auto u : imp)
    //    cout << u << ' ' << dis2[u][0] << endl;
    //cout << dis2[0][1] << ' ' << dis2[0][2] << ' ' << dis2[0][3] << ' ' << dis2[0][4] << ' ' << dis2[0][5] << endl;
    cout << ans << endl;
}
# 결과 실행 시간 메모리 Grader output
1 Correct 3 ms 4948 KB Output is correct
2 Correct 3 ms 4948 KB Output is correct
3 Correct 3 ms 4972 KB Output is correct
4 Correct 3 ms 5076 KB Output is correct
5 Correct 25 ms 5016 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 539 ms 23592 KB Output is correct
2 Correct 554 ms 22852 KB Output is correct
3 Correct 280 ms 18824 KB Output is correct
4 Correct 59 ms 13900 KB Output is correct
5 Correct 400 ms 22828 KB Output is correct
6 Correct 55 ms 13780 KB Output is correct
7 Correct 4 ms 5204 KB Output is correct
8 Correct 4 ms 5204 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 496 ms 13112 KB Output is correct
2 Correct 469 ms 13252 KB Output is correct
3 Correct 299 ms 13132 KB Output is correct
4 Correct 145 ms 9092 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Execution timed out 6006 ms 25532 KB Time limit exceeded
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 111 ms 15420 KB Output isn't correct
2 Halted 0 ms 0 KB -