#include <bits/stdc++.h>
using namespace std;
// Short type aliases used by the solution.
using ll = long long;
using ld = long double;
using pl = std::pair<ll, ll>;
using pii = std::pair<int, int>;
using tpl = std::tuple<int, int, int>;

// all(a): begin/end iterator pair of container a.
// filter(a): erase consecutive duplicates from a (std::unique semantics).
#define all(a) a.begin(), a.end()
#define filter(a) a.erase(unique(all(a)), a.end())

// Capacity of the per-node arrays (node ids are used directly as indices).
constexpr int mn = 1000006;

// spt[i][s]: OR of locked-edge bits stored at DFS positions [i, i + 2^s).
int spt[mn][21];
int num[mn];    // DFS (pre-order) index of each node, heavy child first
int chain[mn];  // head (topmost node) of the node's heavy chain
int depth[mn];  // depth of each node
int par[mn];    // parent of each node
int sz[mn];     // subtree size of each node
int timeDfs;    // last DFS index handed out
// dp[mask][k]: bitmask DP over collected locked-edge keys (filled in main).
int dp[1 << 20][40];
// preDist/preMask[i][k]: distance and locked-edge mask from a key node to a door endpoint.
int preDist[20][40];
int preMask[20][40];
// Adjacency lists of the tree.
std::vector<int> adj[mn];
namespace preprocess {
int szDfs (int u, int p) {
sz[u] = 1;
for (int v : adj[u])
if (v != p) sz[u] += szDfs(v, u);
return sz[u];
}
void dfs (int u, int p, int d, bool toP) {
if (u == 1) szDfs(u, p);
chain[u] = (toP ? chain[p] : u), par[u] = p;
num[u] = ++timeDfs, depth[u] = d;
if (adj[u].size() == (u > 1)) return; // leaf
int big = *max_element(all(adj[u]), [&] (int i, int j) {
return (i == p || j == p ? i == p : sz[i] < sz[j]);
});
dfs(big, u, d + 1, 1);
for (int v : adj[u])
if (v != p && v != big) dfs(v, u, d + 1, 0);
}
};
// OR of the spt level-0 bitmasks over DFS positions [l, r] (inclusive),
// answered with two overlapping power-of-two windows; 0 for an empty range.
int querySparse (int l, int r) {
    if (r < l) return 0;
    const int len = r - l + 1;
    const int lg = 31 - __builtin_clz(len);   // floor(log2(len))
    const int leftPart = spt[l][lg];
    const int rightPart = spt[r + 1 - (1 << lg)][lg];
    return leftPart | rightPart;
}
// Lowest common ancestor of a and b using the heavy-light chains.
// While they lie on different chains, move to the parent of the chain head
// whose parent is deeper (both when tied). Once on one chain, the shallower
// node is the answer.
int lca (int a, int b) {
    for (; chain[a] != chain[b];) {
        const int upA = par[chain[a]];
        const int upB = par[chain[b]];
        const bool moveA = depth[upA] >= depth[upB];
        const bool moveB = depth[upB] >= depth[upA];
        if (moveA) a = upA;
        if (moveB) b = upB;
    }
    if (depth[a] < depth[b]) return a;
    return b;
}
// Bitmask of the locked edges on the tree path between a and b.
// Each locked edge's bit is stored at the DFS position of its deeper endpoint
// (the edge from a node to its parent), so an OR over a contiguous num-range
// of a heavy chain collects the bits of the edges along that chain segment.
int getMask (int a, int b) {
    int ans = 0;
    // Plain lambda instead of std::function: getMask runs inside the DP loop,
    // and std::function would add a type-erased, non-inlinable indirect call.
    auto lift = [&] (int &u) {
        // Edges from chain[u] down to u, including the edge above chain[u].
        ans |= querySparse(num[chain[u]], num[u]);
        u = par[chain[u]];
    };
    // Same chain-climbing rule as lca(): lift the side whose chain head's
    // parent is deeper, or both when tied.
    while (chain[a] != chain[b]) {
        int ap = par[chain[a]], bp = par[chain[b]];
        if (depth[ap] == depth[bp]) lift(a), lift(b);
        else if (depth[ap] > depth[bp]) lift(a);
        else lift(b);
    }
    // Same chain now: edges strictly below the shallower node down to the deeper one.
    if (depth[a] > depth[b]) swap(a, b);
    return ans | querySparse(num[a] + 1, num[b]);
}
// True iff every bit set in sub is also set in mask (sub is a subset of mask).
bool isSub (int sub, int mask) {
    return (sub & ~mask) == 0;
}
// True iff every locked edge on the path u..v has its bit in mask,
// i.e. the path can be walked holding exactly the keys in mask.
bool check (int u, int v, int mask) {
    const int required = getMask(u, v);
    return isSub(required, mask);
}
// Number of edges on the tree path between u and v.
int dist (int u, int v) {
    const int meet = lca(u, v);
    const int upFromU = depth[u] - depth[meet];
    const int upFromV = depth[v] - depth[meet];
    return upFromU + upFromV;
}
int main()
{
ios::sync_with_stdio(0);
cin.tie(0);
/// read input and preprocess
int n, S, T, counter = 0;
cin >> n >> S >> T;
vector<tpl> locked;
for (int i = 1; i < n; i++) {
int a, b, c; cin >> a >> b >> c;
adj[a].push_back(b);
adj[b].push_back(a);
if (c) {
locked.emplace_back(a, b, c);
counter++;
}
}
preprocess::dfs(1, 0, 1, 0);
/// build sparse table
for (int i = 0; i < counter; i++) {
int a, b, c; tie(a, b, c) = locked[i];
if (depth[a] < depth[b]) swap(a, b);
spt[num[a]][0] = 1 << i;
}
for (int s = 1; (1 << s) <= n; s++) {
int p = s - 1;
for (int i = 1; i + (1 << s) - 1 <= n; i++)
spt[i][s] = spt[i][p] | spt[i + (1 << p)][p];
}
/// pre-calculations
for (int i = 0; i < counter; i++) {
int c = get<2>(locked[i]);
for (int j = 0; j < counter; j++) {
int u, v, w; tie(u, v, w) = locked[j];
preDist[i][j] = dist(c, u), preDist[i][j + counter] = dist(c, v);
preMask[i][j] = getMask(c, u), preMask[i][j + counter] = getMask(c, v);
}
}
/// prepare DP base-cases
if (check(S, T, 0)) return cout << dist(S, T) << "\n", 0;
for (int mask = 0; mask < (1 << counter); mask++)
fill(dp[mask], dp[mask] + 2 * counter, INT_MAX);
for (int i = 0; i < counter; i++) {
int a, b, c; tie(a, b, c) = locked[i];
if (!check(S, c, 0)) continue;
// from S to c to either a or b
int mask = 1 << i;
if (check(c, a, 0)) dp[mask][i] = dist(S, c) + preDist[i][i];
if (check(c, b, 0)) dp[mask][i + counter] = dist(S, c) + preDist[i][i + counter];
}
/// run bitmask DP
int full = (1 << counter) - 1, ans = INT_MAX;
for (int mask = 1; mask < (1 << counter); mask++) {
vector<int> inMask, outMask;
for (int i = 0; i < counter; i++) {
if (mask >> i & 1) inMask.push_back(i);
else outMask.push_back(i);
}
for (int last : inMask) {
int a, b, c; tie(a, b, c) = locked[last];
for (int nxt : outMask) {
int u, v, w; tie(u, v, w) = locked[nxt];
int nxtMask = mask | (1 << nxt);
if (dp[mask][last] != INT_MAX && isSub(preMask[nxt][last], mask)) { // from a to w to either u or v
int cur = dp[mask][last] + preDist[nxt][last];
if (isSub(preMask[nxt][nxt], mask)) // w to u
dp[nxtMask][nxt] = min(dp[nxtMask][nxt], cur + preDist[nxt][nxt]);
if (isSub(preMask[nxt][nxt + counter], mask)) // w to v
dp[nxtMask][nxt + counter] = min(dp[nxtMask][nxt + counter], cur + preDist[nxt][nxt + counter]);
}
if (dp[mask][last + counter] != INT_MAX && isSub(preMask[nxt][last + counter], mask)) { // from b to w to either u or v
int cur = dp[mask][last + counter] + preDist[nxt][last + counter];
if (isSub(preMask[nxt][nxt], mask)) // w to u
dp[nxtMask][nxt] = min(dp[nxtMask][nxt], cur + preDist[nxt][nxt]);
if (isSub(preMask[nxt][nxt + counter], mask)) // w to v
dp[nxtMask][nxt + counter] = min(dp[nxtMask][nxt + counter], cur + preDist[nxt][nxt + counter]);
}
}
// update answer
if (dp[mask][last] != INT_MAX && check(a, T, mask))
ans = min(ans, dp[mask][last] + dist(a, T));
if (dp[mask][last + counter] != INT_MAX && check(b, T, mask))
ans = min(ans, dp[mask][last + counter] + dist(b, T));
// if (dp[mask][last] != INT_MAX) cout << "dist " << a << " " << bitset<3>(mask) << " " << dp[mask][last] << "\n";
// if (dp[mask][last + counter] != INT_MAX) cout << "dist " << b << " " << bitset<3>(mask) << " " << dp[mask][last + counter] << "\n";
}
}
cout << (ans == INT_MAX ? -1 : ans) << "\n";
return 0;
}
/*
Judge status table pasted from the submission page (results had not loaded):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/