#include <iostream>
#include <vector>
using namespace std;
// (address, previous value) pair; DSU::his stores these so writes to par[] can be undone.
#define ixi pair <int *, int>
// Capacity of every per-node array (node ids are used directly as indices).
const int N = 1e6 + 1;
// First dimension of dp; main indexes dp by a bitmask below (1 << m), so this assumes m <= 20.
const int M = 1 << 20;
// "Not reached" value for dp entries and for the answer before it is improved.
const int INF = 1e9;
// Read from input in main, in this order. n - 1 edges are read afterwards;
// st and en are the two nodes the answer is computed for.
int n, m, st, en;
// x, y, w: scratch for one input edge.
// in[v] / out[v]: first / last position of node v in the Euler sequence written by DFS.
int x, y, w, in[N], out[N];
// d[v]: 1-based DFS visit order of v; tim: visit counter.
// E[i][0]: d-value at Euler position i (1-based); E[i][j]: minimum over positions
// [i, i + 2^j - 1], filled in main. cnt: number of Euler positions written.
int d[N], tim, E[N << 1][22], cnt;
// Log2[i]: floor(log2(i)) for i >= 1, filled in main.
// depth[v]: number of edges from the DFS start node to v. nd[t]: node with d[node] == t.
int Log2[N << 1], depth[N], nd[N];
// dp[mask][k][side]: best accumulated dis() total found in main for a route whose set of
// chosen indices into e is `mask`, whose most recently added index is k, and which
// currently stands at e[k].u[side]. INF when no such route has been found.
int dp[M][20][2];
// g[v]: adjacency list of v. on / off: indices j whose bit is set / clear in the mask
// currently being processed in main.
vector <int> g[N], on, off;
struct DSU {
int par[N];
vector <ixi> his;
DSU () {
his.clear();
fill_n(par, N, -1);
}
int FindPar(int u, bool roll) {
if (par[u] < 0) return u;
if (!roll) {
return par[u] = FindPar(par[u], roll);
} else {
return FindPar(par[u], roll);
}
}
void unite(int u, int v, bool roll) {
u = FindPar(u, roll);
v = FindPar(v, roll);
if (roll) {
his.push_back(ixi(&par[u], par[u]));
his.push_back(ixi(&par[v], par[v]));
}
if (par[u] > par[v]) swap(u, v);
par[u] += par[v];
par[v] = u;
}
bool same(int u, int v, bool roll) {
return FindPar(u, roll) == FindPar(v, roll);
}
void roll() {
while (!his.empty()) {
*his.back().first = his.back().second;
his.pop_back();
}
}
} dsu;
// One stored edge: its two endpoints plus the third value read alongside it.
struct obj {
    int u[2];  // endpoints; u[0] is the first one given, u[1] the second
    int h;     // third constructor argument
    // Leaves all members uninitialised.
    obj () {}
    // Stores U and V as the endpoints and H as h.
    obj (int U, int V, int H) : u{U, V}, h(H) {}
};
// Edges read in main with w != 0, in input order. main uses an edge's position in
// this vector as its bit index in the dp bitmask.
vector <obj> e;
// Walks the tree from s (treating p as s's parent) and fills the structures LCA
// relies on:
//   d[v]       1-based visit order of v, with nd[d[v]] == v
//   depth[z]   depth[parent] + 1 for every node reached below s
//   E[i][0]    Euler sequence of d-values: a node is appended on first visit and
//              again after each of its children has been fully explored
//   in[v]/out[v]  first / last Euler position at which v was appended
//
// Implemented with an explicit stack instead of recursion: the recursion depth
// would equal the height of the tree, and with up to N nodes a path-shaped tree
// can overflow the call stack. The order of all writes is the same as in the
// recursive formulation.
void DFS(int s = 1, int p = -1) {
    struct Frame {
        int node;    // vertex being explored
        int parent;  // neighbour that must be skipped
        int next;    // index in g[node] of the next neighbour to look at
    };
    vector <Frame> frames;

    // First-visit bookkeeping for v, then start exploring it.
    auto enter = [&frames](int v, int from) {
        d[v] = ++tim;
        nd[tim] = v;
        E[++cnt][0] = d[v];
        in[v] = out[v] = cnt;
        frames.push_back(Frame{v, from, 0});
    };

    enter(s, p);
    while (!frames.empty()) {
        Frame &top = frames.back();
        if (top.next < static_cast<int>(g[top.node].size())) {
            const int z = g[top.node][top.next++];
            if (z == top.parent) continue;
            depth[z] = depth[top.node] + 1;
            // push_back may reallocate; `top` is not touched after this call.
            enter(z, top.node);
            continue;
        }
        // All neighbours handled: return to the parent frame and append the
        // parent to the Euler sequence again.
        frames.pop_back();
        if (!frames.empty()) {
            const int up = frames.back().node;
            E[++cnt][0] = d[up];
            out[up] = cnt;
        }
    }
}
// Returns the node with the smallest d-value among Euler positions
// in[A] .. out[B] (after ordering A and B by in[]), using two overlapping
// power-of-two windows of the sparse table E.
int LCA(int A, int B) {
    if (in[B] < in[A]) swap(A, B);
    const int lo = in[A];
    const int hi = out[B];
    const int k = Log2[hi - lo + 1];
    const int leftMin = E[lo][k];
    const int rightMin = E[hi - (1 << k) + 1][k];
    // nd maps the winning d-value back to its node.
    return nd[rightMin < leftMin ? rightMin : leftMin];
}
// Number of edges on the tree path between A and B, computed from the depths of
// A, B and LCA(A, B).
int dis(int A, int B) {
    const int meet = LCA(A, B);
    return depth[A] + depth[B] - 2 * depth[meet];
}
// Reads the tree, prepares the LCA tables, then runs a bitmask DP over the edges
// stored in e and prints the smallest total found (or -1 if none).
int main() {
/*freopen("TREEMAZE.INP", "r", stdin);
freopen("TREEMAZE.OUT", "w", stdout);*/
cin.tie(0); cout.tie(0);
ios_base::sync_with_stdio(false);
// Log2[i] = floor(log2(i)); Log2[0] and Log2[1] stay 0 from static initialisation.
for (int i = 2; i < (N << 1); ++i) {
Log2[i] = Log2[i >> 1] + 1;
}
cin >> n >> m >> st >> en;
// n - 1 edges "x y w". Edges with w == 0 are merged permanently in the DSU;
// the others are kept in e with w stored as obj::h.
// NOTE(review): the code below indexes e[0..m-1] and dp[..][0..m-1], so it assumes
// exactly m edges have w != 0 and m <= 20 — confirm against the problem statement.
for (int i = 1; i < n; ++i) {
cin >> x >> y >> w;
g[x].push_back(y);
g[y].push_back(x);
if (w == 0) {
dsu.unite(x, y, false);
} else {
e.push_back(obj(x, y, w));
}
}
// Euler tour from node 1, then the range-minimum sparse table over it.
DFS();
for (int j = 1; (1 << j) <= cnt; ++j) {
for (int i = 1; i + (1 << j) - 1 <= cnt; ++i) {
E[i][j] = min(E[i][j - 1], E[i + (1 << (j - 1))][j - 1]);
}
}
// st and en already joined by w == 0 edges: the answer is their tree distance.
if (dsu.same(st, en, false)) {
cout << dis(st, en) << "\n";
return 0;
}
// Mark every dp state as not reached.
for (int i = 0; i < (1 << m); ++i) {
for (int j = 0; j < m; ++j) {
dp[i][j][0] = dp[i][j][1] = INF;
}
}
// Base states: edge i can be first if node e[i].h is in st's component; the route
// goes st -> e[i].h -> endpoint e[i].u[j], for each endpoint also in st's component.
for (int i = 0; i < m; ++i) {
if (dsu.same(st, e[i].h, false)) {
for (int j = 0; j < 2; ++j) {
if (dsu.same(st, e[i].u[j], false)) {
dp[(1 << i)][i][j] = dis(st, e[i].h) + dis(e[i].h, e[i].u[j]);
}
}
}
}
// Compress every path now: the undo-logged finds used below never compress.
for (int i = 1; i <= n; ++i) dsu.FindPar(i, false);
int ans = INF;
// Masks are visited in increasing order; every transition writes to a larger mask,
// so a state is final by the time its mask is processed.
for (int i = 1; i < (1 << m); ++i) {
on.clear(); off.clear();
// Undo the previous mask's unions, then unite the edges selected by this mask.
dsu.roll();
for (int j = 0; j < m; ++j) {
if (i >> j & 1) {
on.push_back(j);
dsu.unite(e[j].u[0], e[j].u[1], true);
} else off.push_back(j);
}
// en is connected to st under this mask: finish from every state of this mask.
// Unreached states hold INF, so their sum exceeds INF and never lowers ans.
if (dsu.same(st, en, true)) {
for (int z: on) {
for (int j = 0; j < 2; ++j) {
ans = min(ans, dp[i][z][j] + dis(e[z].u[j], en));
}
}
}
// Transitions: from (i, A, ita) add an edge B not in the mask when e[B].h is in the
// same component as the current position e[A].u[ita], ending at an endpoint of B
// that is in the same component as e[B].h.
for (int A: on) {
for (int ita = 0; ita < 2; ++ita) {
if (dp[i][A][ita] == INF) continue;
for (int B: off) {
for (int itb = 0; itb < 2; ++itb) {
if (dsu.same(e[A].u[ita], e[B].h, true) && dsu.same(e[B].u[itb], e[B].h, true)) {
dp[i | (1 << B)][B][itb] = min(dp[i | (1 << B)][B][itb], dp[i][A][ita] + dis(e[A].u[ita], e[B].h) + dis(e[B].h, e[B].u[itb]));
}
}
}
}
}
}
// No mask connected st to en with a reached state: report -1.
if (ans == INF) ans = -1;
cout << ans << "\n";
return 0;
}
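// The lines below are an online-judge results table that was captured together
// with the source; they are not part of the program.
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |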