#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
#define pb push_back
// Child-index helpers (lc = 2 * id, rc = 2 * id + 1); not used in this file.
#define lc (id * 2)
#define rc (lc + 1)
// Short accessors for pair members.
#define S second
#define F first
// Untie/unsync the C++ streams for fast input; expanded once at the top of main.
#define migmig cin.tie(0); cout.tie(0); ios::sync_with_stdio(false)
// #define int ll
// NOTE(review): these pragmas come after <bits/stdc++.h>, so they only apply to
// code defined below them, not to the already-included library headers.
#pragma GCC optimize("Ofast,unroll-loops,inline")
#pragma GCC target("avx2,bmi2,lzcnt")
// Capacity for per-vertex / per-edge arrays.
const int maxn = 1e6 + 10;
// Not used in this file.
const int mod = 1e9 + 7;
// Number of binary-lifting levels. Also sizes the DP: dp has LOG rows and
// 1 << LOG mask columns, and dis is LOG x LOG, so at most LOG key-gated edges fit.
const int LOG = 21;
// "Unreachable" sentinel; leaves headroom so inf plus a few path lengths stays below INT_MAX.
const int inf = 1e9 + 10;
// Global tables (edges with w > 0 are the "doors" that need the key at vertex w):
//   dp[i][mask] - DP over doors, filled in main.
//                 NOTE(review): LOG * 2^LOG ints is ~176 MB; par below is ~84 MB.
//   par[v][i]   - 2^i-th ancestor of v, filled by dfs.
//   qmask[v]    - bitmask of door indices on the path from the dfs root to v.
//   h[v]        - depth of v below the dfs root.
//   f[v]        - dist(v, t).
//   dis[i][j]   - walk length from door i's parent side to key j, then to door j's parent side.
//   key[i]      - weight of input edge i: 0 = free edge, otherwise the vertex holding its key.
int n, m, s, t, dp[LOG][(1 << LOG)], par[maxn][LOG], qmask[maxn], h[maxn], f[maxn], dis[LOG][LOG], key[maxn];
// Distance-to-s table; not read anywhere in this file.
int fs[maxn];
// f[i] = dist(i, t);
vector<pii> G[maxn], edge;
vector<int> e;
// Roots the tree at v (whose parent is recorded as p) and fills, for every
// vertex x reachable from v without stepping back to p:
//   par[x][i] - the 2^i-th ancestor of x (binary lifting table),
//   h[x]      - depth of x below v (h[v] itself is left untouched),
//   qmask[x]  - qmask[v] plus one bit per edge on the path v -> x whose stored
//               index w is non-negative (edges tagged with w < 0 add nothing).
// Uses an explicit stack instead of recursion: with up to maxn (~1e6) vertices,
// a path-shaped tree would otherwise recurse ~1e6 frames deep and overflow a
// default-sized call stack.
void dfs(int v, int p) {
    // Each entry is (vertex, its parent). A child is pushed only after its
    // parent was popped and fully processed, so all of the child's ancestors
    // already have complete par[] rows when the child's own row is built.
    std::vector<std::pair<int, int>> stk;
    stk.emplace_back(v, p);
    while (!stk.empty()) {
        const auto [x, px] = stk.back();  // copy out before popping
        stk.pop_back();
        par[x][0] = px;
        for (int i = 1; i < LOG; i++)
            par[x][i] = par[par[x][i - 1]][i - 1];
        for (const auto [u, w] : G[x]) {
            if (u == px) continue;
            qmask[u] = (w >= 0) ? (qmask[x] | (1 << w)) : qmask[x];
            h[u] = h[x] + 1;
            stk.emplace_back(u, x);
        }
    }
}
// Returns the ancestor d levels above v, taking one binary-lifting jump
// (par[.][bit]) for every set bit of d. Requires par to be filled (see dfs).
int get_par(int v, int d) {
    int cur = v;
    for (int bit = 0; bit < LOG; ++bit)
        if (d >> bit & 1)
            cur = par[cur][bit];
    return cur;
}
// Lowest common ancestor of v and u using the binary-lifting table:
// bring the deeper vertex up to the shallower one's depth, then raise both
// together while their 2^i-th ancestors still differ.
int LCA(int v, int u) {
    int shallow = v, deep = u;
    if (h[shallow] > h[deep])
        std::swap(shallow, deep);
    deep = get_par(deep, h[deep] - h[shallow]);
    if (deep == shallow)
        return deep;  // one vertex was an ancestor of the other
    for (int i = LOG; i-- > 0;) {
        const int a = par[shallow][i];
        const int b = par[deep][i];
        if (a != b) {
            shallow = a;
            deep = b;
        }
    }
    return par[shallow][0];
}
// Number of edges on the tree path between v and u:
// depth(v) + depth(u) - 2 * depth(LCA(v, u)).
int dist(int v, int u) {
    return h[v] + h[u] - 2 * h[LCA(v, u)];
}
// Bitmask of qmask-tracked edges on the tree path between v and u.
// Edges on the common part of the two root paths appear in both masks and
// cancel under XOR, leaving only the edges between v and u.
int path(int v, int u) {
    return qmask[u] ^ qmask[v];
}
int P[maxn];
signed main() {
migmig;
cin >> n >> s >> t;
int ted = 0;
for (int i = 0; i < n - 1; i++) {
int v, u, w;
cin >> v >> u >> w;
key[i] = w;
if (w > 0)
ted++;
int x = ted - 1;
if (w == 0)
x = -1;
G[v].pb({u, x});
G[u].pb({v, x});
edge.pb({v, u});
if (w > 0)
e.pb(i);
}
dfs(s, 0);
for (int i = 0; i < n - 1; i++) {
int v = edge[i].F, u = edge[i].S;
if (v == par[u][0])
P[i] = v;
else
P[i] = u;
}
for (int i = 1; i <= n; i++)
f[i] = dist(i, t), fs[i] = dist(i, s);
for (int i = 0; i < e.size(); i++) {
for (int j = 0; j < e.size(); j++) {
dis[i][j] = dist(P[e[i]], key[e[j]]) + dist(key[e[j]], P[e[j]]);
}
}
// swap(s, t);
m = int(e.size());
int ans = inf;
for (int mask = 0; mask < (1 << m); mask++) {
for (int i = 0; i < m; i++) {
//dp[i][mask]
dp[i][mask] = inf;
if ((path(t, P[e[i]]) & mask) == 0) {
dp[i][mask] = f[P[e[i]]];
continue;
}
for (int bt = 0; bt < m; bt++) {
if (((1 << bt) & mask) == 0) continue;
if ((path(key[e[bt]], P[e[bt]]) & mask) == 0 && ((path(P[e[i]], key[e[bt]]) & mask) == 0 ))
dp[i][mask] = min(dp[i][mask], dis[i][bt] + dp[bt][mask ^ (1 << bt)]);
}
}
}
for (int i = 0; i < m; i++) {
int mask = (1 << m) - 1;
if (path(s, key[e[i]]) == 0 && path(key[e[i]], P[e[i]]) == 0) {
ans = min(ans, dist(s, key[e[i]]) + dist(key[e[i]], P[e[i]]) + dp[i][mask ^ (1 << i)]);
}
}
if (path(s, t) == 0)
ans = min(ans, dist(s, t));
// for (int mask = 0; mask < (1 << m); mask++) {
// for (int i = 0; i < m; i++) {
// //dp[i][mask]
// cout << i << ' ' << mask << ' ' << dp[i][mask] << '\n';
// }
// }
if (ans >= inf)
ans = -1;
cout << ans << '\n';
}
/*
 * Judge status table pasted from the submission page (not part of the program):
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */