#include <bits/stdc++.h>
// Every `int` below is widened to 64-bit; main() is therefore declared `signed`.
#define int long long
// #define ll long long
#define pii pair<int, int>
#define all(v) v.begin(), v.end()
using namespace std;
// Sentinel "no valid route found" value for the answer in solve().
const int oo = 1e18 + 9;
// MAX: capacity of the per-node arrays (nodes are indexed 1..n).
// LOGMAX: number of binary-lifting levels (2^20 > MAX).
// B and MOD are not referenced anywhere in the visible code.
const int MAX = 1e6 + 5, LOGMAX = 20, B = 800, MOD = 1e9 + 7;
// Restricted edges {u, v, w} with w != 0. solve() treats w as a node that must
// be visited before the edge (u, v) may be crossed.
vector<array<int, 3>> d;
// n: number of nodes; S: start node; T: target node (all read in solve()).
int n, S, T;
struct DSU{
    // Union-find with undo support.
    // par[x] < 0  : x is a root and -par[x] is the size of its component.
    // par[x] >= 0 : par[x] is x's parent.
    int par[MAX];

    // Make every node its own singleton component (memset with -1 fills each
    // element with all-ones bytes, i.e. the value -1).
    void init(){
        memset(par, -1, sizeof(par));
    }

    // Follow parent links up to the root. There is deliberately no path
    // compression: every link change must be one unite() can log and undo.
    int get(int u){
        while(par[u] >= 0) u = par[u];
        return u;
    }

    // Undo log. Each entry is
    // {kept root, absorbed root, kept root's old par, absorbed root's old par}.
    stack<array<int, 4>> st;

    // Merge the components containing u and v, hanging the smaller root under
    // the larger one (ties keep u's root). When b is true the merge is logged
    // so roll_back() can revert it.
    void unite(int u, int v, bool b){
        int keep = get(u);
        int drop = get(v);
        if(keep == drop) return;
        if(par[keep] > par[drop]) swap(keep, drop);
        if(b) st.push({keep, drop, par[keep], par[drop]});
        par[keep] += par[drop];
        par[drop] = keep;
    }

    // True when u and v are currently in the same component.
    bool same(int u, int v){
        const int ru = get(u);
        const int rv = get(v);
        return ru == rv;
    }

    // Revert the most recently logged unite(). Precondition: st is non-empty.
    void roll_back(){
        const array<int, 4> last = st.top();
        st.pop();
        const int keptRoot = last[0], absorbedRoot = last[1];
        par[keptRoot] = last[2];
        par[absorbedRoot] = last[3];
    }
};
// Global DSU over tree nodes; solve() merges the unrestricted (w == 0) edges
// into it up front and logs/rolls back restricted-edge merges per permutation.
DSU dsu;
// Adjacency lists of the tree; solve() inserts each edge in both directions.
vector<int> g[MAX];
// in/out: Euler-tour interval of each node, and H: depth from root 1 (set by dfs()).
// par[j][x]: levels j >= 1 are built in solve() as par[j-1][par[j-1][x]];
// dist() treats par[0][x] as x's parent.
// NOTE(review): with `int` = long long, par alone reserves LOGMAX * MAX * 8 bytes
// (about 160 MB) of static storage; confirm against the memory limit.
int in[MAX], out[MAX], par[LOGMAX][MAX], H[MAX];
// Preorder counter advanced by dfs().
int t = 0;
void dfs(int node, int p, int h){
H[node] = h;
in[node] = ++t;
for(int to : g[node]){
if(to == p) continue;
dfs(to, node, h + 1);
}
out[node] = t;
}
// True when u is an ancestor of v (every vertex counts as its own ancestor):
// v's Euler interval [in[v], out[v]] lies inside u's.
bool isA(int u, int v){
    const bool startsInside = in[v] >= in[u];
    const bool endsInside = out[v] <= out[u];
    return startsInside && endsInside;
}
// Number of edges on the tree path between u and v.
// Uses depths H, the Euler intervals checked by isA(), and the lifting table
// par[j][x] (par[0][x] being x's parent).
int dist(int u, int v){
    // One endpoint above the other: the path is a straight vertical segment.
    if(isA(u, v)) return H[v] - H[u];
    if(isA(v, u)) return H[u] - H[v];
    // Otherwise lift u to the highest ancestor that is still NOT an ancestor
    // of v; that vertex's parent is the LCA.
    int x = u;
    for(int j = LOGMAX; j-- > 0; ){
        const int up = par[j][x];
        if(!isA(up, v)) x = up;
    }
    const int lca = par[0][x];
    return H[u] + H[v] - 2 * H[lca];
}
// Reads the tree and prints the answer.
// Input: n, S, T, then n-1 edges "u v w". Edges with w == 0 are merged into the
// DSU immediately (freely traversable). Edges with w != 0 go into d: to cross
// (u, v) one must first walk to node w.
// Every crossing order of the restricted edges is tried via next_permutation,
// so the running time is exponential in the number of such edges.
void solve(){
    cin >> n >> S >> T;
    dsu.init();
    for(int i = 1; i < n; i++){
        int u, v, w;
        cin >> u >> v >> w;
        if(w) d.push_back({u, v, w});
        else dsu.unite(u, v, 0);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1, 1, 0);
    // Binary lifting: par[j][i] is the 2^j-th ancestor of i.
    for(int j = 1; j < LOGMAX; j++){
        for(int i = 1; i <= n; i++){
            par[j][i] = par[j - 1][par[j - 1][i]];
        }
    }
    int ans = oo;
    // next_permutation enumerates every order only when started from the
    // lexicographically smallest arrangement.
    sort(all(d));
    do{
        int cur = S;
        int D = 0;
        bool ok = true;
        for(const auto& a : d){
            // T is already reachable: the remaining edges are not needed
            // (tree paths are unique, so going there directly is optimal).
            if(dsu.same(cur, T)) break;
            // Walk to the required node first; fail if it is unreachable.
            if(!dsu.same(cur, a[2])){
                ok = false;
                break;
            }
            D += dist(cur, a[2]);
            cur = a[2];
            // Walk to whichever endpoint of the edge is on our side, then open it.
            if(dsu.same(cur, a[1])){
                D += dist(cur, a[1]);
                dsu.unite(a[0], a[1], 1);
                cur = a[1];
            }
            else if(dsu.same(cur, a[0])){
                D += dist(cur, a[0]);
                dsu.unite(a[0], a[1], 1);
                cur = a[0];
            }
            else{
                ok = false;
                break;
            }
        }
        // Final leg to T, checked after the loop so it is also counted when T
        // only becomes reachable after the last edge of this order is opened,
        // and when there are no restricted edges at all. Previously such
        // orders were accepted without adding (or verifying) this leg.
        if(ok){
            if(dsu.same(cur, T)) D += dist(cur, T);
            else ok = false;
        }
        // Undo every logged merge so the next order starts from the base DSU.
        while(dsu.st.size()) dsu.roll_back();
        if(ok) ans = min(ans, D);
    }
    while(next_permutation(all(d)));
    // NOTE(review): if no order reaches T this prints oo + 1; confirm the
    // expected output for the unreachable case against the problem statement.
    cout << ans + 1 << '\n';
}
// Entry point (`signed` because `int` is macro-expanded to long long).
signed main(){
    // solve() reads about 3 * (n - 1) integers through cin; in synchronized
    // mode that is very slow for n near 1e6. No C stdio is used anywhere, so
    // unsyncing and untying the streams is safe.
    ios::sync_with_stdio(0);
    cin.tie(0);
    // Single test case. Named `tests` so it no longer shadows the global
    // preorder counter `t` used by dfs().
    int tests = 1;
    while(tests--) solve();
}
/*
Compilation message
Main.cpp: In function 'void solve()':
Main.cpp:88:9: warning: unused variable 'k' [-Wunused-variable]
   88 |     int k = d.size();
      |         ^
*/
/*
Grader results (table 1)
| # | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory) | Grader output        |
| 1 | Correct       | 9 ms             | 76120 KB        | Output is correct    |
| 2 | Correct       | 9 ms             | 76124 KB        | Output is correct    |
| 3 | Incorrect     | 71 ms            | 96540 KB        | Output isn't correct |
| 4 | Halted        | 0 ms             | 0 KB            | -                    |
*/
/*
Grader results (table 2)
| # | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory) | Grader output        |
| 1 | Correct       | 9 ms             | 76124 KB        | Output is correct    |
| 2 | Correct       | 9 ms             | 76124 KB        | Output is correct    |
| 3 | Incorrect     | 169 ms           | 76168 KB        | Output isn't correct |
| 4 | Halted        | 0 ms             | 0 KB            | -                    |
*/
/*
Grader results (table 3)
| # | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory) | Grader output        |
| 1 | Correct       | 9 ms             | 76120 KB        | Output is correct    |
| 2 | Correct       | 9 ms             | 76124 KB        | Output is correct    |
| 3 | Incorrect     | 71 ms            | 96540 KB        | Output isn't correct |
| 4 | Halted        | 0 ms             | 0 KB            | -                    |
*/