Submission #635450

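| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 635450 | | S2speed | LOSTIKS (INOI20_lostiks) | C++17 | 100 / 100 | 1545 ms | 247468 KiB |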
#include <algorithm>
#include <cstring>
#include <iostream>
#include <utility>
#include <vector>

#pragma GCC optimize("Ofast")

// Tree shortest walk with locked doors (keys must be fetched before a door
// can be crossed). Input (1-indexed): n s t, then n-1 lines "v u w"; w == 0
// means the edge has no door, otherwise w is the vertex used as the location
// of that door's key. Output: minimum walk length from s to t, or -1.
//
// Approach:
//   * BFS from s (rBFS) records depth dis[0], binary-lifting ancestors jad,
//     g[v] = bitmask of doors on the s->v path, and d[j] = the endpoint of
//     door j on the s side.
//   * BFS from t (BFS) records dis[1] = distance to t.
//   * Bitmask DP over the (<= ~20) doors:
//     dp[mask][j] = shortest walk that has opened exactly the doors in mask,
//     the last one being j, ending at d[j].

using pii = std::pair<int, int>;
using piii = std::pair<pii, int>;

constexpr int maxn = (1 << 20) + 17;
constexpr int inf = 200000000;
constexpr int kLog = 20;  // binary-lifting levels; depths must be < 2^kLog

int n, s, t;
pii adj[maxn << 1];      // CSR adjacency: {neighbour, door id or -1}
std::vector<piii> ed;    // input edges: {{v, u}, door id or -1}
int dg[maxn], st[maxn], ft[maxn], x[maxn];  // degree, CSR [st, ft), fill cursor
int dis[2][maxn];        // [0]: distance from s, [1]: distance from t
int bfs[maxn], sz = 0;   // BFS queue
// g: door mask on s->v path; k/d: key vertex / s-side endpoint of each door;
// e[j]: doors that must already be open before door j can be opened;
// f[i][j]: cost to go from d[i] to key j and then on to d[j].
// NOTE(review): dp is indexed by 1 << m, so this assumes m <= 20 doors.
int g[maxn], k[22], d[22], e[22], f[22][22], dp[maxn][22], ind[22];
int jad[maxn][kLog];     // jad[v][j]: 2^j-th ancestor of v (or -1)

// BFS from root r filling dis[h], plus ancestors (jad), door masks (g) and
// the s-side endpoint of every door (d). Only used with r = s, h = 0.
void rBFS(int r, int h) {
    sz = 0;
    dis[h][r] = 0;
    g[r] = 0;
    bfs[sz++] = r;
    for (int head = 0; head < sz; head++) {
        const int v = bfs[head];
        for (int ei = st[v]; ei < ft[v]; ei++) {
            const int to = adj[ei].first;
            const int door = adj[ei].second;
            if (dis[h][to] < dis[h][v] + 1) continue;  // already discovered
            jad[to][0] = v;
            for (int j = 1; j < kLog; j++) {
                if (jad[to][j - 1] == -1) break;  // above the root: stays -1
                jad[to][j] = jad[jad[to][j - 1]][j - 1];
            }
            g[to] = g[v];
            if (door != -1) {
                g[to] ^= (1 << door);
                d[door] = v;  // parent side of the door edge
            }
            dis[h][to] = dis[h][v] + 1;
            bfs[sz++] = to;
        }
    }
}

// Ancestor of v at the given depth (depth <= dis[0][v]).
int find_jad(int v, int depth) {
    const int up = dis[0][v] - depth;
    for (int j = 0; j < kLog; j++) {
        if (up & (1 << j)) v = jad[v][j];
    }
    return v;
}

// Lowest common ancestor of v and u in the tree rooted at s.
int lca(int v, int u) {
    if (dis[0][v] > dis[0][u]) std::swap(v, u);
    u = find_jad(u, dis[0][v]);
    if (v == u) return v;
    for (int j = kLog - 1; j >= 0; j--) {
        if (jad[v][j] != jad[u][j]) {
            v = jad[v][j];
            u = jad[u][j];
        }
    }
    return jad[v][0];
}

// Number of edges on the tree path between v and u (ignores doors).
int dist(int v, int u) {
    const int l = lca(v, u);
    return dis[0][v] + dis[0][u] - (dis[0][l] << 1);
}

// Plain BFS from r filling dis[h] only.
void BFS(int r, int h) {
    sz = 0;
    dis[h][r] = 0;
    bfs[sz++] = r;
    for (int head = 0; head < sz; head++) {
        const int v = bfs[head];
        for (int ei = st[v]; ei < ft[v]; ei++) {
            const int to = adj[ei].first;
            if (dis[h][to] < dis[h][v] + 1) continue;  // already discovered
            dis[h][to] = dis[h][v] + 1;
            bfs[sz++] = to;
        }
    }
}

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::memset(jad, -1, sizeof(jad));
    std::memset(dp, 63, sizeof(dp));   // 0x3f3f3f3f: "unreachable", > inf
    std::memset(dis, 63, sizeof(dis));

    std::cin >> n >> s >> t;
    s--;
    t--;
    int m = 0;  // number of doors
    for (int i = 1; i < n; i++) {
        int v, u, w;
        std::cin >> v >> u >> w;
        v--;
        u--;
        w--;
        dg[v]++;
        dg[u]++;
        if (w != -1) {
            k[m] = w;
            ed.push_back({{v, u}, m});
            m++;
        } else {
            ed.push_back({{v, u}, -1});
        }
    }

    // Build the CSR adjacency: vertex i owns adj[st[i] .. ft[i]).
    x[0] = st[0] = 0;
    ft[0] = dg[0];
    for (int i = 1; i < n; i++) {
        x[i] = st[i] = ft[i - 1];
        ft[i] = st[i] + dg[i];
    }
    for (const auto& p : ed) {
        const int v = p.first.first, u = p.first.second, w = p.second;
        adj[x[v]++] = {u, w};
        adj[x[u]++] = {v, w};
    }

    rBFS(s, 0);
    if (g[t] == 0) {  // no door between s and t
        std::cout << dis[0][t] << '\n';
        return 0;
    }
    BFS(t, 1);

    // Base case: a door whose key and near side are reachable with no doors
    // opened can be the first door opened (walk s -> key -> door).
    for (int j = 0; j < m; j++) {
        e[j] = g[k[j]] | g[d[j]];
        if (e[j] == 0) {
            dp[(1 << j)][j] = dis[0][k[j]] + dist(d[j], k[j]);
        }
    }
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < m; j++) {
            f[i][j] = dist(d[i], k[j]) + dist(d[j], k[j]);
        }
    }

    const int lm = (1 << m);
    for (int mask = 3; mask < lm; mask++) {
        int bp = 0;
        for (int j = 0; j < m; j++) {
            if (mask & (1 << j)) ind[bp++] = j;
        }
        if (bp == 1) continue;  // singletons handled by the base case
        for (int jj = 0; jj < bp; jj++) {
            const int j = ind[jj];
            const int msk = mask ^ (1 << j);
            // Every prerequisite door of j must already be open *before* j is
            // opened. Testing against msk (not mask) rejects a key that lies
            // behind its own door, matching the base case's e[j] == 0 test.
            if ((e[j] & msk) != e[j]) continue;
            int h = inf;
            for (int ii = 0; ii < bp; ii++) {
                const int i = ind[ii];
                if (i == j) continue;
                h = std::min(h, dp[msk][i] + f[i][j]);
            }
            dp[mask][j] = h;
        }
    }

    // Finish from the last opened door once every door on s->t is open.
    int ans = inf;
    for (int mask = 1; mask < lm; mask++) {
        if ((g[t] & mask) != g[t]) continue;
        for (int i = 0; i < m; i++) {
            if (!(mask & (1 << i))) continue;
            ans = std::min(ans, dp[mask][i] + dis[1][d[i]]);
        }
    }
    std::cout << (ans == inf ? -1 : ans) << '\n';
    return 0;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...