This submission was migrated from the previous version of oj.uz, which used a different machine for grading. It may get a different result if resubmitted.
// In the name of God
//#pragma GCC optimize("O2", "unroll-loops")
#include<bits/stdc++.h>
using namespace std;
#define pb push_back
#define fi first
#define se second
#define mp make_pair
typedef long long ll;
// N: capacity of the per-node arrays (node ids are presumably 1..n; n is read in solve()).
// lg: one more than the maximum number of locked edges c (solve() asserts c <= 20).
// An edge index i <= c fits in lg slots; index c is the target sentinel.
// Every subset mask of the c locked edges is below 1 << (lg - 1).
const int N = 1e6 + 5, lg = 21;
int n, c, s, t;                      // node count, locked-edge count, start node, target node
vector<int> adj[N];                  // tree adjacency lists
vector<pair<int, int> > E;           // locked edges; solve() appends the sentinel (t, t) at index c
// f[i]: node that must be visited before locked edge i can be used.
// f is sized by N, not lg, because it is written while reading input, before the c <= 20 check.
// tin/tout: DFS entry/exit stamps taken from counter T; h: depth from the DFS root.
int f[N], tin[N], tout[N], T, h[N];
int dis[lg][N];                      // dis[i][x]: tree distance from locked edge i's parent endpoint to x
// Masks range over subsets of at most 20 locked edges, so 1 << (lg - 1) rows suffice.
// A 1 << lg first dimension would double these tables (dp alone would take ~352 MB instead of ~176 MB).
ll dp[1 << (lg - 1)][lg];
bool ok[1 << (lg - 1)][lg];
// Assigns DFS entry/exit stamps tin/tout and depths h for every node reachable from v
// without stepping back to p. Stamps come from the shared counter T; h[v] itself is left unchanged.
// Neighbours are visited in adjacency-list order, so the stamps are identical to the
// straightforward recursive formulation. An explicit stack is used because a path-shaped
// tree of ~1e6 nodes would overflow the call stack under recursion.
void dfs(int v, int p) {
    struct Frame {
        int node, parent;
        std::size_t next; // index of the next neighbour of `node` to examine
    };
    std::vector<Frame> frames;
    tin[v] = ++T;
    frames.push_back({v, p, 0});
    while (!frames.empty()) {
        const int cur = frames.back().node;
        const int par = frames.back().parent;
        std::size_t &next = frames.back().next;
        if (next < adj[cur].size()) {
            const int u = adj[cur][next++];
            if (u != par) {
                h[u] = h[cur] + 1;
                tin[u] = ++T;
                // push_back may reallocate; `next` is not touched again in this iteration.
                frames.push_back({u, cur, 0});
            }
        } else {
            // All neighbours done: this is where the recursive version would return.
            tout[cur] = ++T;
            frames.pop_back();
        }
    }
}
// Propagates distances outward from v over the tree, never stepping back to p.
// Every node x on v's side gets dis[k][x] = dis[k][v] + (tree distance from v to x).
// Uses an explicit work stack instead of recursion.
// Not called anywhere in the visible source.
void dfs2(int v, int p, int k) {
    std::vector<std::pair<int, int> > pending(1, std::make_pair(v, p));
    while (!pending.empty()) {
        const int cur = pending.back().first;
        const int from = pending.back().second;
        pending.pop_back();
        for (int nxt : adj[cur]) {
            if (nxt == from)
                continue;
            dis[k][nxt] = dis[k][cur] + 1;
            pending.push_back(std::make_pair(nxt, cur));
        }
    }
}
// True iff v is an ancestor of u (or u itself) in the DFS tree.
// That holds exactly when u's [tin, tout] interval nests inside v's.
bool anc(int v, int u) {
    const bool entered_no_later = tin[v] <= tin[u];
    const bool left_no_earlier = tout[u] <= tout[v];
    return entered_no_later && left_no_earlier;
}
/// Reads one instance and prints the length of the shortest valid walk from s to t, or -1 if none exists.
///
/// Input: `n s t`, then n-1 edges `v u w`. When w != 0 the edge v-u is locked, and node w
/// (its "key") must be visited before the edge may be used.
///
/// Model, as implemented below: dp[mask][i] is the shortest walk that has unlocked exactly
/// the edges in `mask`, unlocked edge i last, and ends at edge i's parent-side endpoint.
/// Each unlock walks to the key f[i] and then to that endpoint.
void solve() {
    cin >> n >> s >> t;
    for (int i = 0; i < n - 1; i++) {
        int v, u, w; cin >> v >> u >> w;
        adj[v].pb(u), adj[u].pb(v);
        if (w)
            f[(int)E.size()] = w, E.pb(mp(v, u));
    }
    dfs(s, 0);
    // NOTE(review): this makes node 0 an ancestor of every node under anc().
    // Node 0 is never passed to anc() in this function, so this looks vestigial.
    tout[0] = ++T;
    c = (int)E.size();
    assert(c <= 20);
    // Only rows 0..c-1 of dis are read below.
    memset(dis, 63, sizeof(dis[0]) * c);
    for (int i = 0; i < c; i++) {
        // Orient the edge: .fi is the endpoint closer to s (the DFS parent), and
        // .se is the child whose subtree the lock cuts off.
        if (tin[E[i].fi] > tin[E[i].se])
            swap(E[i].fi, E[i].se);
        // BFS over the whole tree, ignoring locks:
        // dis[i][x] = distance from edge i's parent endpoint to x.
        queue<int> q;
        q.push(E[i].fi);
        dis[i][E[i].fi] = 0;
        while (!q.empty()) {
            int v = q.front();
            q.pop();
            for (auto u : adj[v]) {
                if (dis[i][u] > dis[i][v] + 1) {
                    dis[i][u] = dis[i][v] + 1;
                    q.push(u);
                }
            }
        }
    }
    // Sentinel "edge" c stands for the target: both its endpoint and its key are t.
    E.pb(mp(t, t));
    f[c] = t;

    // need[i]: bitmask of locked edges j whose child endpoint is an ancestor of E[i].fi or of f[i].
    // Such an edge separates s from one of those nodes. Edge i (or the target, when i == c) is
    // usable after unlocking `mask` iff every such j is in `mask`.
    // This costs O(c^2), instead of O(2^c * c^2) for a full (mask, i) table.
    std::vector<int> need(c + 1, 0);
    for (int i = 0; i <= c; i++)
        for (int j = 0; j < c; j++)
            if (anc(E[j].se, E[i].fi) || anc(E[j].se, f[i]))
                need[i] |= 1 << j;
    auto reachable = [&need](int mask, int i) { return (need[i] & ~mask) == 0; };

    // Only rows below 1 << c are read.
    memset(dp, 63, sizeof(dp[0]) * (1 << c));
    for (int i = 0; i < c; i++) {
        // The first unlocked edge must be usable with nothing unlocked yet.
        // Checking with edge i itself treated as open would let its key sit behind its own lock.
        if (reachable(0, i))
            dp[1 << i][i] = h[f[i]] + dis[i][f[i]];
    }
    ll ans = 1e18;
    if (reachable(0, c))
        ans = h[t];
    for (int mask = 1; mask < (1 << c); mask++) {
        for (int i = 0; i < c; i++) {
            if (!(mask & (1 << i)))
                continue;
            int mask2 = mask ^ (1 << i);
            // Edge i must be usable with only the edges unlocked before it.
            if (reachable(mask2, i)) {
                for (int j = 0; j < c; j++) {
                    if (mask2 & (1 << j))
                        dp[mask][i] = min(dp[mask][i], dp[mask2][j] + dis[j][f[i]] + dis[i][f[i]]);
                }
            }
            if (reachable(mask, c))
                ans = min(ans, dp[mask][i] + dis[i][t]);
        }
    }
    if (ans < 1e18)
        cout << ans << '\n';
    else
        cout << -1 << '\n';
}
int32_t main() {
    // Decouple C++ streams from C stdio and untie them for fast bulk I/O in solve().
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    solve();
    return 0;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |