#pragma GCC optimize("Ofast" , "unroll-loops")
#include <bits/stdc++.h>
using namespace std;
// Array bounds: MAXN vertices, MAXM locked edges (edge masks are 1 << MAXM wide);
// inf is the "unreachable / no answer" sentinel used for dp and ans.
const int MAXN = 1e6 + 16 , MAXM = 20 , inf = 1e9;
// Vertex count, start vertex and target vertex (0-indexed once main has read them).
int n , s , t;
// Adjacency list of (neighbour, weight). As read, weight 0 marks an unlocked
// edge and any other value w names vertex w (1-indexed). main later rewrites
// each weight in place to the locked-edge id, or -1 for unlocked edges.
vector<pair<int,int>> adj[MAXN];
// For locked-edge id i: x[i] = its endpoint nearer to s, k[i] = the 0-indexed
// vertex named by its weight (presumably where its key lies - TODO confirm).
// cur = number of locked edges found so far.
// NOTE(review): x/k hold only MAXM entries; more than MAXM locked edges overflow them.
int x[MAXM] , k[MAXM] , cur = 0;
// (a, b) -> locked-edge id, stored for both orientations of each locked edge.
map<pair<int,int>,int> kid;
// needs[v] = bitmask of locked-edge ids on the path s -> v; diss[v] = depth of v from s.
int needs[MAXN] , diss[MAXN];
void dfs0(int v , int p){
for(auto [u , w] : adj[v]){
if(u == p) continue;
needs[u] = needs[v];
if(w != 0){
x[cur] = v;
k[cur] = w-1;
kid[{v , u}] = kid[{u , v}] = cur;
needs[u] |= (1 << cur);
cur++;
}
diss[u] = diss[v] + 1;
dfs0(u , v);
}
}
int need[MAXM][MAXN] , dis[MAXM][MAXN];
void dfs1(int v , int p , int i){
for(auto [u , w] : adj[v]){
if(u == p) continue;
need[i][u] = need[i][v];
if(w != -1) need[i][u] |= (1 << w);
dis[i][u] = dis[i][v] + 1;
dfs1(u , v , i);
}
}
// dp[msk][i]: minimum remaining walk length from x[i] to t when the locked
// edges in msk are open (computed in main; inf = impossible / pruned state).
// NOTE(review): (1 << 20) * 20 ints is about 80 MiB.
int dp[1<<MAXM][MAXM];
// Submask test: true iff no bit set in m1 is missing from m2.
inline bool sub(const int& m1 , const int& m2){
    const int missing = m1 & ~m2;
    return missing == 0;
}
// Best total over all choices of the first locked edge to open; inf means none works (prints -1).
int ans = inf;
// Reads the tree, labels locked edges, precomputes per-edge path data, runs a
// bitmask DP over sets of opened locked edges, and prints the minimum total
// walk length from s to t under that model, or -1 if t cannot be reached.
int main(){
ios::sync_with_stdio(false) , cin.tie(0);
// Input: n, then s and t (1-indexed), then n-1 edges "u v w".
cin >> n;
cin >> s >> t;
s-- , t--;
for(int i = 0 ; i < n-1 ; i++){
int u , v , w;
cin >> u >> v >> w;
u-- , v--;
adj[u].push_back({v , w});
adj[v].push_back({u , w});
}
// Root at s: fills needs/diss and assigns ids 0..cur-1 to the locked edges.
dfs0(s , s);
// Relabel every edge weight in place: locked edge -> its id, unlocked -> -1
// (the encoding dfs1 expects).
for(int v = 0 ; v < n ; v++){
for(auto& [u , w] : adj[v]){
if(w != 0){
w = kid[{v , u}];
}
else{
w = -1;
}
}
}
// For each locked edge i, path masks and distances measured from x[i].
for(int i = 0 ; i < cur ; i++){
dfs1(x[i] , x[i] , i);
}
for(int msk = 0 ; msk < (1 << cur) ; msk++){
for(int i = 0 ; i < cur ; i++){
dp[msk][i] = inf;
}
}
// dp[msk][i]: standing at x[i] with the locked edges in msk open (i in msk).
// Masks are processed from largest to smallest, so dp[msk | (1 << j)] is
// already final when msk is handled.
for(int msk = (1 << cur) - 1 ; msk >= 0 ; msk--){
// Prune masks that cannot occur: for every open edge i, the paths
// s -> x[i] and s -> k[i] must use only edges that are also open.
int flag = 1;
for(int i = 0 ; i < cur ; i++){
if(sub(1 << i , msk) and (!sub(needs[x[i]] , msk) or !sub(needs[k[i]] , msk))) flag = 0;
}
if(flag == 0) continue;
for(int i = 0 ; i < cur ; i++){
if(!sub(1 << i , msk)) continue;
// The path x[i] -> t is fully open: walk it directly (the tree path is the shortest one).
if(sub(need[i][t] , msk)) dp[msk][i] = dis[i][t];
else{
// Otherwise open some closed edge j next: walk x[i] -> k[j], then
// k[j] -> x[j], both along open edges only, and continue from x[j].
for(int j = 0 ; j < cur ; j++){
if(sub(1 << j , msk)) continue;
if(sub(need[i][k[j]] , msk) and sub(need[j][k[j]] , msk)){
dp[msk][i] = min(dp[msk][i] , dis[i][k[j]] + dis[j][k[j]] + dp[msk | (1 << j)][j]);
}
}
}
}
}
if(needs[t] == 0){
// No locked edge lies between s and t: answer is the plain tree distance.
cout << diss[t] << endl;
}
else{
// The first edge i opened needs k[i] reachable from s and x[i] reachable
// from k[i] with nothing open yet.
for(int i = 0 ; i < cur ; i++){
if(needs[k[i]] == 0 and need[i][k[i]] == 0){
ans = min(ans , diss[k[i]] + dis[i][k[i]] + dp[1 << i][i]);
}
}
if(ans == inf) cout << -1 << endl;
else cout << ans << endl;
}
}
/*
Judge status table pasted from the grader page (not part of the program):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/