#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using ld = long double;
using pl = pair<ll,ll>;
using pii = pair<int,int>;
using tpl = tuple<int,int,int>;
#define all(a) a.begin(), a.end()
#define filter(a) a.erase(unique(all(a)), a.end())
// Capacity of every per-node array (nodes are numbered from 1).
const int mn = 1e6 + 6;

// ---- bitmask DP over the locked edges (main asserts there are at most 20) ----
// dp[mask][j]: shortest walk found for key set `mask`; j < counter means the walk
// ends at nodeA[j], j >= counter means it ends at nodeB[j - counter].
int dp[1 << 20][40];
// Endpoints of the i-th locked edge, plus the third input value c read with it.
int nodeA[mn];
int nodeB[mn];
int nodeLock[mn];
// dstA/dstB[i][j]: tree distance from nodeLock[i] to nodeA[j] / nodeB[j];
// mskA/mskB[i][j]: bitmask of locked edges on that same path.
int dstA[20][20];
int dstB[20][20];
int mskA[20][20];
int mskB[20][20];

// ---- heavy-light decomposition of the tree rooted at node 1 ----
// spt[x][k]: OR of lock bits over DFS positions [x, x + 2^k); a locked edge's bit
// is stored at the DFS position of its deeper endpoint.
int spt[mn][20];
int chain[mn];   // head of the heavy chain containing the node
int sz[mn];      // subtree size
int num[mn];     // DFS entry time; heavy child first, so each chain is contiguous
int depth[mn];   // depth, the root has depth 1
int par[mn];     // parent, 0 for the root
int timeDfs;     // DFS clock that hands out num[]
vector<int> adj[mn];
// Computes sz[] for the subtree of u (entered from parent p) and returns sz[u].
int szDfs (int u, int p) {
    int total = 1;   // u itself
    for (int v : adj[u]) {
        if (v == p) continue;
        total += szDfs(v, u);
    }
    sz[u] = total;
    return total;
}
// Heavy-light DFS: fills par, chain, depth and num for the subtree of u.
// toP tells whether u is the heavy child of p (and so continues p's chain).
void dfs (int u, int p, int d, bool toP) {
    // Subtree sizes drive the heavy-child choice; compute them once from root 1.
    if (u == 1) szDfs(u, p);
    par[u] = p;
    chain[u] = toP ? chain[p] : u;
    depth[u] = d;
    num[u] = ++timeDfs;
    // Leaf: the root has no neighbours, any other node only has its parent.
    const size_t leafDegree = (u > 1) ? 1 : 0;
    if (adj[u].size() == leafDegree) return;
    // Heavy child = the first non-parent neighbour with the largest subtree.
    int heavy = -1;
    for (int v : adj[u]) {
        if (v == p) continue;
        if (heavy == -1 || sz[v] > sz[heavy]) heavy = v;
    }
    // Visiting the heavy child first keeps every chain contiguous in num[].
    dfs(heavy, u, d + 1, true);
    for (int v : adj[u]) {
        if (v == p || v == heavy) continue;
        dfs(v, u, d + 1, false);
    }
}
// Lowest common ancestor via the heavy-light chains.
int lca (int a, int b) {
    while (chain[a] != chain[b]) {
        const int upA = par[chain[a]];
        const int upB = par[chain[b]];
        // Jump the side whose chain exits higher in depth; jump both on a tie.
        const bool moveA = depth[upA] >= depth[upB];
        const bool moveB = depth[upB] >= depth[upA];
        if (moveA) a = upA;
        if (moveB) b = upB;
    }
    return depth[a] < depth[b] ? a : b;
}
// OR of lock bits over DFS positions [u, v]; an empty range yields 0.
int query (int u, int v) {
    if (v < u) return 0;
    const int len = v - u + 1;
    const int lvl = 31 - __builtin_clz(len);   // floor(log2(len))
    const int left = spt[u][lvl];
    const int right = spt[v - (1 << lvl) + 1][lvl];
    return left | right;
}
// Returns the bitmask of locked edges on the tree path between a and b.
// A locked edge's bit lives in the sparse table at the DFS number of its deeper
// endpoint, i.e. a node's slot describes the edge to its parent.
int getMask (int a, int b) {
    int mask = 0;
    // Move u up to the parent of its chain head, collecting every edge passed:
    // the chain edges of nodes (chain[u] .. u] AND the light edge
    // chain[u] -> par[chain[u]], whose bit sits at num[chain[u]]. Starting the
    // range at num[chain[u]] + 1 (as before) silently dropped that light edge.
    auto lift = [&] (int &u) {
        mask |= query(num[chain[u]], num[u]);
        u = par[chain[u]];
    };
    while (chain[a] != chain[b]) {
        int ap = par[chain[a]], bp = par[chain[b]];
        if (depth[ap] == depth[bp]) lift(a), lift(b);
        else if (depth[ap] > depth[bp]) lift(a);
        else lift(b);
    }
    // Same chain: edges of the nodes strictly below the shallower one, down to the deeper one.
    if (depth[a] > depth[b]) swap(a, b);
    return mask | query(num[a] + 1, num[b]);
}
// True when every locked edge on the u-v path has its bit set in mask.
bool check (int u, int v, int mask) {
    const int missing = getMask(u, v) & ~mask;
    return missing == 0;
}
// True when every bit of sub is also set in mask (sub is a subset of mask).
bool checkFast (int sub, int mask) {
    return (sub & ~mask) == 0;
}
// Number of edges on the tree path between u and v.
int dist (int u, int v) {
    const int top = lca(u, v);
    return (depth[u] - depth[top]) + (depth[v] - depth[top]);
}
int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);
    /// read input: n, S, T, then n - 1 edges "a b c". An edge with c != 0 is locked,
    /// gets the next bit index, and c is stored in nodeLock (presumably the node
    /// holding that edge's key — TODO confirm against the problem statement).
    int n, S, T, counter = 0;
    cin >> n >> S >> T;
    vector<pii> locked;
    for (int i = 1; i < n; i++) {
        int a, b, c; cin >> a >> b >> c;
        adj[a].push_back(b);
        adj[b].push_back(a);
        if (c) {
            locked.emplace_back(a, b);
            nodeA[counter] = a, nodeB[counter] = b;
            nodeLock[counter++] = c;
        }
    }
    dfs(1, 0, 1, 0);
    // dp and the 20x20 tables are sized for at most 20 locked edges.
    assert(counter <= 20);
    /// build sparse table: each locked edge's bit goes to the DFS position of its deeper endpoint
    for (int i = 0; i < counter; i++) {
        int a, b; tie(a, b) = locked[i];
        if (depth[a] < depth[b]) swap(a, b);
        spt[num[a]][0] = (1 << i);
    }
    for (int s = 1; (1 << s) <= n; s++) {
        int p = s - 1;
        for (int i = 1; i + (1 << s) - 1 <= n; i++)
            spt[i][s] = spt[i][p] | spt[i + (1 << p)][p];
    }
    /// pre-calculations: distance and required key set from every nodeLock to every edge endpoint
    for (int i = 0; i < counter; i++) {
        for (int j = 0; j < counter; j++) {
            dstA[i][j] = dist(nodeLock[i], nodeA[j]);
            dstB[i][j] = dist(nodeLock[i], nodeB[j]);
            mskA[i][j] = getMask(nodeLock[i], nodeA[j]);
            mskB[i][j] = getMask(nodeLock[i], nodeB[j]);
        }
    }
    /// setup base cases for DP: S -> nodeLock[i] -> nodeA[i] / nodeB[i], holding no keys yet
    for (int mask = 0; mask < (1 << counter); mask++)
        for (int j = 0; j < 2 * counter; j++) dp[mask][j] = INT_MAX;
    for (int i = 0; i < counter; i++) {
        int mask = (1 << i);
        if (check(S, nodeLock[i], 0)) {
            if (checkFast(mskA[i][i], 0)) dp[mask][i] = dist(S, nodeLock[i]) + dstA[i][i];
            if (checkFast(mskB[i][i], 0)) dp[mask][i + counter] = dist(S, nodeLock[i]) + dstB[i][i];
        }
    }
    /// run DP
    int full = (1 << counter) - 1, ans = INT_MAX;
    // Walking straight from S to T needs no key when that path has no locked edge.
    // The DP below only considers routes through at least one nodeLock, so without
    // this seed such inputs (e.g. ones with no locked edges at all) printed -1.
    if (check(S, T, 0)) ans = dist(S, T);
    for (int mask = 1; mask < (1 << counter); mask++) {
        for (int sub = mask; sub; sub -= sub & (-sub)) {
            int last = __builtin_ctz(sub);
            if (dp[mask][last] == INT_MAX && dp[mask][last + counter] == INT_MAX) continue;
            // extend with every edge index not yet in mask
            for (int sub2 = mask ^ full; sub2; sub2 -= sub2 & (-sub2)) {
                int nxt = __builtin_ctz(sub2), dA = dstA[nxt][nxt], dB = dstB[nxt][nxt];
                bool checkA = checkFast(mskA[nxt][nxt], mask), checkB = checkFast(mskB[nxt][nxt], mask);
                if (checkFast(mskA[nxt][last], mask) && dp[mask][last] != INT_MAX) {
                    int cur = dp[mask][last] + dstA[nxt][last];
                    if (checkA)
                        dp[mask | (1 << nxt)][nxt] = min(dp[mask | (1 << nxt)][nxt], cur + dA);
                    if (checkB)
                        dp[mask | (1 << nxt)][nxt + counter] = min(dp[mask | (1 << nxt)][nxt + counter], cur + dB);
                }
                if (checkFast(mskB[nxt][last], mask) && dp[mask][last + counter] != INT_MAX) {
                    int cur = dp[mask][last + counter] + dstB[nxt][last];
                    if (checkA)
                        dp[mask | (1 << nxt)][nxt] = min(dp[mask | (1 << nxt)][nxt], cur + dA);
                    if (checkB)
                        dp[mask | (1 << nxt)][nxt + counter] = min(dp[mask | (1 << nxt)][nxt + counter], cur + dB);
                }
            }
            // finish: walk from the current endpoint to T using the keys in mask
            if (check(nodeA[last], T, mask) && dp[mask][last] != INT_MAX)
                ans = min(ans, dp[mask][last] + dist(nodeA[last], T));
            if (check(nodeB[last], T, mask) && dp[mask][last + counter] != INT_MAX)
                ans = min(ans, dp[mask][last + counter] + dist(nodeB[last], T));
        }
    }
    cout << (ans == INT_MAX ? -1 : ans) << "\n";
    return 0;
}
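/*
 * Leftover judge results table pasted after the program (results were still
 * loading when it was copied). Kept as a comment so the file compiles:
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */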