#include <bits/stdc++.h>
#define ll long long
#define endl "\n"
using namespace std;
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());

// Capacity limits. N bounds vertex ids; lg is both the number of binary-lifting
// levels and the maximum number of distinct waypoints (bits in a DP mask).
constexpr ll N = 1e6 + 5;
constexpr ll lg = 22;

vector<pair<ll, ll>> g[N]; // g[v] = (neighbor, edge label); label 0 means "no requirement"
ll need[N];                // NOTE(review): not referenced by the visible code
ll comp[N];                // vertex id -> waypoint index
ll depth[N];               // distance (in edges) from the DFS root
ll up[lg][N];              // up[k][v] = 2^k-th ancestor of v
ll vals[lg];               // waypoint index -> vertex id
ll dp[1 << lg][lg];        // dp[mask][i]: shortest walk from s visiting waypoints in mask, ending at waypoint i
ll freq[lg][N];            // freq[k][v]: # edges on root->v path whose label is waypoint k
// Pre-order traversal of the tree rooted at v, with `par` recorded as v's parent.
// For every vertex u reached it fills:
//   up[k][u]   - 2^k-th ancestor (binary-lifting table),
//   depth[u]   - depth[parent] + 1 (depth[v] itself is left untouched),
//   freq[k][u] - freq[k][parent] plus one if the parent edge's label maps to waypoint k.
// Implemented with an explicit stack: a recursive version recurses once per tree
// level, which on a path-shaped tree of ~1e6 vertices can overflow the call stack.
// Each vertex's values depend only on its ancestors, so visiting siblings in a
// different order than the recursive version yields identical tables.
void dfs(ll v, ll par = 0)
{
    vector<pair<ll, ll>> stk; // (vertex, parent) pairs awaiting processing
    stk.emplace_back(v, par);
    while (!stk.empty())
    {
        const auto [cur, p] = stk.back(); // copy before pop_back
        stk.pop_back();

        // Ancestors of cur were all processed earlier, so their up[] rows are ready.
        up[0][cur] = p;
        for (ll i = 1; i < lg; i++)
            up[i][cur] = up[i - 1][up[i - 1][cur]];

        for (const auto &[to, w] : g[cur])
        {
            if (to == p)
                continue;
            for (ll i = 0; i < lg; i++)
                freq[i][to] = freq[i][cur];
            if (w) // labelled edge: count it toward its waypoint's index
                freq[comp[w]][to]++;
            depth[to] = depth[cur] + 1;
            stk.emplace_back(to, cur);
        }
    }
}
// Lowest common ancestor of a and b using the binary-lifting table up[][]
// and depth[] filled by dfs().
ll lca(ll a, ll b)
{
    // Make `a` the deeper endpoint.
    if (depth[b] > depth[a])
        swap(a, b);

    // Lift `a` greedily, largest jump first, without going above b's depth.
    for (ll k = lg; k-- > 0;)
    {
        if (depth[a] - (1 << k) >= depth[b])
            a = up[k][a];
    }
    if (a == b)
        return a;

    // Lift both while their 2^k-th ancestors differ; they end just below the LCA.
    for (ll k = lg; k-- > 0;)
    {
        if (up[k][a] == up[k][b])
            continue;
        a = up[k][a];
        b = up[k][b];
    }
    return up[0][a];
}
// Number of edges on the tree path between a and b.
ll dist(ll a, ll b)
{
    const ll anc = lca(a, b);
    const ll upFromA = depth[a] - depth[anc];
    const ll upFromB = depth[b] - depth[anc];
    return upFromA + upFromB;
}
// Bitmask of waypoints required to walk the tree path a -> b: bit k is set iff
// at least one edge on the path carries a label whose waypoint index is k.
// Uses the root-prefix counts freq[k][*] built by dfs().
ll needs(ll a, ll b)
{
    const ll anc = lca(a, b);
    ll mask = 0;
    for (ll bit = 0; bit < lg; bit++)
    {
        const ll onPath = freq[bit][a] + freq[bit][b] - 2 * freq[bit][anc];
        if (onPath != 0)
            mask |= 1LL << bit;
    }
    return mask;
}
void solve()
{
ll n;
cin >> n;
ll s, t;
cin >> s >> t;
array<ll, 3> e[n - 1];
for (ll i = 0; i < n - 1; i++)
{
cin >> e[i][0] >> e[i][1] >> e[i][2];
g[e[i][0]].push_back(make_pair(e[i][1], e[i][2]));
g[e[i][1]].push_back(make_pair(e[i][0], e[i][2]));
}
ll sz = 0;
{
set<ll> tmp;
for (ll i = 1; i <= n; i++)
for (auto [j, w] : g[i])
if (w)
tmp.insert(w);
tmp.insert(s);
tmp.insert(t);
for (ll i : tmp) comp[i] = sz, vals[sz++] = i;
}
dfs(s, s);
for (ll i = 0; i < (1 << sz); i++)
for (ll j = 0; j < sz; j++)
dp[i][j] = 1e18;
dp[1 << comp[s]][comp[s]] = 0;
for (ll msk = 0; msk < (1 << sz); msk++) if (msk != (1 << comp[s]))
{
for (ll i = 0; i < sz; i++)
if (msk >> i & 1)
for (ll j = 0, x; j < sz; j++)
{
x = needs(vals[i], vals[j]);
if ((msk >> j & 1) and j != i and ((msk ^ (1 << i)) & x) == x)
dp[msk][i] = min(dp[msk][i], dp[msk ^ (1 << i)][j] + dist(vals[i], vals[j]));
}
}
ll ans = 1e18;
for (ll msk = 0; msk < (1 << sz); msk++)
ans = min(ans, dp[msk][comp[t]]);
if (ans == 1e18)
ans = -1;
cout << ans << endl;
}
int main()
{
ios_base::sync_with_stdio(0);
cin.tie(0);
ll t = 1;
// precomp();
// cin >> t;
for (ll cs = 1; cs <= t; cs++)
solve();
// cerr << "\nTime elapsed: " << clock() * 1000.0 / CLOCKS_PER_SEC << " ms\n";
}
/*
Grader results pasted from the judge (kept as a comment so the file compiles).

Result table 1:
| # | Verdict   | Execution time | Memory    | Grader output        |
|---|-----------|----------------|-----------|----------------------|
| 1 | Correct   | 9 ms           | 73564 KB  | Output is correct    |
| 2 | Correct   | 8 ms           | 73780 KB  | Output is correct    |
| 3 | Correct   | 75 ms          | 134052 KB | Output is correct    |
| 4 | Incorrect | 79 ms          | 133456 KB | Output isn't correct |
| 5 | Halted    | 0 ms           | 0 KB      | -                    |

Result table 2:
| # | Verdict   | Execution time | Memory    | Grader output        |
|---|-----------|----------------|-----------|----------------------|
| 1 | Correct   | 9 ms           | 73564 KB  | Output is correct    |
| 2 | Correct   | 9 ms           | 73564 KB  | Output is correct    |
| 3 | Incorrect | 24 ms          | 75900 KB  | Output isn't correct |
| 4 | Halted    | 0 ms           | 0 KB      | -                    |

Result table 3:
| # | Verdict   | Execution time | Memory    | Grader output        |
|---|-----------|----------------|-----------|----------------------|
| 1 | Correct   | 9 ms           | 73564 KB  | Output is correct    |
| 2 | Correct   | 8 ms           | 73780 KB  | Output is correct    |
| 3 | Correct   | 75 ms          | 134052 KB | Output is correct    |
| 4 | Incorrect | 79 ms          | 133456 KB | Output isn't correct |
| 5 | Halted    | 0 ms           | 0 KB      | -                    |
*/