Submission #1178357

#TimeUsernameProblemLanguageResultExecution timeMemory
1178357ezzzayCities (BOI16_cities)C++20
51 / 100
6091 ms22892 KiB
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define ff first
#define ss second
#define pb push_back
const int N=3e5+5;
vector<pair<int,int>>v[N];
int cit[10];
int dist[N];
int n,m,k;
int par[25];
int sbtr[25];
int tmp1[N],tmp2[N],tmp3[N];
void find(int x){
    for(int i=1;i<=n;i++){
        dist[i]=1e15;
       
    }
    
    dist[x]=0;
    
    priority_queue<pair<int,int>>q;
    q.push({0,x});
    while(!q.empty()){
        int w=-q.top().ff;
        int a=q.top().ss;
        q.pop();
        if(dist[a]<w)continue;
        for(auto p:v[a]){
            int b=p.ff;
            int c=p.ss;
            if(dist[b]>dist[a]+c){
                dist[b]=dist[a]+c;
                q.push({-dist[b],b});
            }
        }
    }
}
int dst[1005][1005];
void fun(int x){
    for(int i=1;i<=n;i++){
        dst[x][i]=1e15;
    }
    dst[x][x]=0;
    priority_queue<pair<int,int>>q;
    q.push({0,x});
    while(!q.empty()){
        int w=-q.top().ff;
        int a=q.top().ss;
        q.pop();
        if(dst[x][a]<w)continue;
        for(auto p:v[a]){
            int b=p.ff;
            int c=p.ss;
            if(dst[x][b]>dst[x][a]+c){
                dst[x][b]=dst[x][a]+c;
                q.push({-dst[x][b],b});
            }
        }
    }
}
int findpar(int x){
    if(par[x]==x)return x;
    return par[x]=findpar(par[x]);
}
 
void dsu(){
    vector<int>vc;
    int ans=1e15;
    for(int i=0;i<k;i++){
        int a;
        cin>>a;
        vc.pb(a);
    }
    for(int i=0;i<m;i++){
        int a,b,c;
        cin>>a>>b>>c;
        v[a].pb({b,c});
        v[b].pb({a,c});
    }
    for(int i=0;i<(1<<n);i++){
        bool init[25];
        int cnt=0;
        for(int i=0;i<25;i++){
            init[i]=0;
        }
        for(int j=0;j<n;j++){
            if(i & (1<<j)){
                init[j+1]=1;
                cnt++;
            }
        }
        bool u=1;
        for(auto c:vc)if(init[c]==0)u=0;
        if(u==0)continue;
        vector<pair<int,pair<int,int>>>tmp;
        
        for(int j=1;j<=n;j++){
            if(init[j]){
                for(auto p:v[j]){
                    if(init[p.ff]){
                        tmp.pb({p.ss,{j,p.ff}});
                    }
                }
            }
        }
        
        for(int j=1;j<=n;j++){
            par[j]=j;
            sbtr[j]=1;
        }
        int s=0;
        sort(tmp.begin(),tmp.end());
        for(auto p:tmp){
            int w=p.ff;
            int x=p.ss.ff;
            int y=p.ss.ss;
            int px=findpar(x);
            int py=findpar(y);
            
            if(px==py)continue;
            
            sbtr[px]+=sbtr[py];
            par[py]=px;
            s+=w;
            cnt--;
        }
        if(cnt==1)ans=min(ans,s);
    }
    cout<<ans<<endl;
}
signed main(){
    cin>>n>>k>>m;
    if(n<=20){
        dsu();
        return 0;
    }
    vector<int>vc;
    for(int i=1;i<=k;i++){
        int a;
        cin>>a;
        vc.pb(a);
    }
    for(int i=1;i<=m;i++){
        int a,b,c;
        cin>>a>>b>>c;
        v[a].pb({b,c});
        v[b].pb({a,c});
    }
    
    if(k==2){
        find(vc[0]);
        for(int i=1;i<=n;i++)tmp1[i]=dist[i];
        find(vc[1]);
        for(int i=1;i<=n;i++)tmp2[i]=dist[i];
        int ans=1e18;
        for(int i=1;i<=n;i++){
            ans=min(ans,tmp1[i]+tmp2[i]);
        }
        cout<<ans;
    }
    else if(k==3){
        find(vc[0]);
        for(int i=1;i<=n;i++)tmp1[i]=dist[i];
        find(vc[1]);
        for(int i=1;i<=n;i++)tmp2[i]=dist[i];
        find(vc[2]);
        for(int i=1;i<=n;i++)tmp3[i]=dist[i];
        
        int ans=1e18;
 
        for(int i=1;i<=n;i++){
            ans=min(ans,tmp1[i]+tmp2[i]+tmp3[i]);
           
        }
        cout<<ans;
    }
    else{
        for(int i=1;i<=n;i++)fun(i);
        int ans=1e18;
        for(int i=1;i<=n;i++){
            for(int j=1;j<=n;j++){
                ans=min(ans,dst[i][j]+dst[vc[0]][i]+dst[vc[1]][i]+dst[vc[2]][j]+dst[vc[3]][j]);
                ans=min(ans,dst[i][j]+dst[vc[0]][i]+dst[vc[2]][i]+dst[vc[1]][j]+dst[vc[3]][j]);
                ans=min(ans,dst[i][j]+dst[vc[0]][i]+dst[vc[3]][i]+dst[vc[2]][j]+dst[vc[1]][j]);
            }
        }
        cout<<ans;
    }
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...