// In the name of God
//#pragma GCC optimize("O2", "unroll-loops")
#include<bits/stdc++.h>
using namespace std;
#define pb push_back
#define fi first
#define se second
#define mp make_pair
typedef long long ll;
// N: array bound for vertex ids (n <= 1e6 plus slack).
// lg: number of binary-lifting levels read by lca(), and also the number of
//     DP columns (up to 20 special edges plus the pseudo-edge for t).
const int N = 1e6 + 5, lg = 21;
int n, c, s, t;                // vertex count, special-edge count, start, target
vector<int> adj[N];            // adjacency lists of the tree
vector<pair<int, int> > E;     // special edges (endpoint pairs); solve() appends (t, t)
// f[i]: the w value read for special edge i (a vertex id, presumably where its
//       key is - confirm against the statement).
// tin/tout: DFS entry/exit timestamps; T: timestamp counter; h: depth from s.
// up: binary-lifting table read by lca().
int f[N], tin[N], tout[N], T, h[N], up[N][lg];
int dis[lg][N];                // dis[i][v]: distance from special edge i's near endpoint to v
// solve() asserts c <= 20 and only indexes masks < 2^c, so 2^20 rows suffice.
// The previous 1 << lg rows doubled memory (~200 MB extra) for no benefit.
ll dp[1 << (lg - 1)][lg];
bool ok[1 << (lg - 1)][lg];
void dfs(int v, int p) {
tin[v] = ++T;
for (auto u : adj[v]) {
if (u != p) {
h[u] = h[v] + 1;
dfs(u, v);
}
}
tout[v] = ++T;
}
// Propagates row k of dis[][] from v into every vertex of the tree reached
// without stepping back to p: each neighbour gets its predecessor's value + 1.
// The caller is expected to have set dis[k][v] beforehand.
// Note: solve() does not call this helper.
void dfs2(int v, int p, int k) {
    for (int idx = 0; idx < (int)adj[v].size(); ++idx) {
        const int nxt = adj[v][idx];
        if (nxt == p)
            continue;
        dis[k][nxt] = dis[k][v] + 1;
        dfs2(nxt, v, k);
    }
}
// True iff v is an ancestor of u (or v == u): u's [tin, tout] interval
// lies inside v's.
bool anc(int v, int u) {
    if (tin[v] > tin[u])
        return false;
    return tout[u] <= tout[v];
}
// Lowest common ancestor of v and u by binary lifting over up[][].
// Expects up[x][k] to hold the 2^k-th ancestor of x, and the sentinel 0 to
// pass anc(0, x) for every x (tin[0] == 0, tout[0] larger than any timestamp).
int lca(int v, int u) {
    int a = v, b = u;
    if (tin[a] > tin[b])
        swap(a, b);
    // a now has the earlier entry time, so b cannot be a proper ancestor of a.
    if (anc(a, b))
        return a;
    // Lift b as high as possible while staying strictly below the LCA.
    for (int k = lg; k-- > 0;) {
        const int nxt = up[b][k];
        if (!anc(nxt, a))
            b = nxt;
    }
    return up[b][0];
}
void solve() {
set<int> st;
cin >> n >> s >> t;
for (int i = 0; i < n - 1; i++) {
int v, u, w; cin >> v >> u >> w;
adj[v].pb(u), adj[u].pb(v);
if (w) {
f[(int)E.size()] = w, E.pb(mp(v, u));
st.insert(w);
}
}
st.insert(t);
dfs(s, 0);
tout[0] = ++T;
c = (int)E.size();
assert(c <= 20);
memset(dis, 63, sizeof dis);
for (int i = 0; i < c; i++) {
if (tin[E[i].fi] > tin[E[i].se])
swap(E[i].fi, E[i].se);
for (auto v : st) {
dis[i][v] = h[E[i].fi] + h[v] - 2 * h[lca(E[i].fi, v)];
}
}
E.pb(mp(t, t));
f[c] = t;
for (int mask = 0; mask < (1 << c); mask++) {
for (int i = 0; i <= c; i++) {
int v = E[i].fi;
ok[mask][i] = true;
for (int j = 0; j < c; j++) {
ok[mask][i] &= (mask & (1 << j)) || (!anc(E[j].se, v) && !anc(E[j].se, f[i]));
}
}
}
memset(dp, 63, sizeof dp);
for (int i = 0; i < c; i++) {
if (ok[1 << i][i]) {
dp[1 << i][i] = h[f[i]] + dis[i][f[i]];
}
}
ll ans = 1e18;
if (ok[0][c]) {
ans = h[t];
}
for (int mask = 1; mask < (1 << c); mask++) {
for (int i = 0; i < c; i++) {
if (!(mask & (1 << i)))
continue;
int mask2 = mask ^ (1 << i);
for (int j = 0; j < c; j++) {
if (mask2 & (1 << j)) {
if (ok[mask2][i]) {
dp[mask][i] = min(dp[mask][i], dp[mask2][j] + dis[j][f[i]] + dis[i][f[i]]);
}
}
}
if (ok[mask][c]) {
ans = min(ans, dp[mask][i] + dis[i][t]);
}
}
}
if (ans < 1e18)
cout << ans << '\n';
else
cout << -1 << '\n';
}
// Entry point: unsynchronised, untied iostreams for fast I/O, then one test.
int32_t main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    solve();
    return 0;
}
/*
Judge results (submission 1):
# | Result    | Execution time | Memory    | Grader output
1 | Correct   | 146 ms         | 450656 KB | Output is correct
2 | Correct   | 145 ms         | 450672 KB | Output is correct
3 | Incorrect | 185 ms         | 455708 KB | Output isn't correct
4 | Halted    | 0 ms           | 0 KB      | -
*/
/*
Judge results (submission 2):
# | Result    | Execution time | Memory    | Grader output
1 | Correct   | 179 ms         | 450672 KB | Output is correct
2 | Correct   | 150 ms         | 450732 KB | Output is correct
3 | Incorrect | 153 ms         | 450716 KB | Output isn't correct
4 | Halted    | 0 ms           | 0 KB      | -
*/
/*
Judge results (submission 3):
# | Result    | Execution time | Memory    | Grader output
1 | Correct   | 146 ms         | 450656 KB | Output is correct
2 | Correct   | 145 ms         | 450672 KB | Output is correct
3 | Incorrect | 185 ms         | 455708 KB | Output isn't correct
4 | Halted    | 0 ms           | 0 KB      | -
*/