#include <bits/stdc++.h>
using namespace std;
#define endl '\n'
#define ll long long
#define all(x) (x).begin(), (x).end()

// Capacity of every node-indexed array; node ids are used directly as indices.
const int mxn = 300000 + 10;

// Adjacency lists of the input tree.
vector<int> g[mxn];
// Edges of the a -> b path, appended by dfs().
vector<pair<int, int>> edg;

// Node count and the two source nodes, read in main().
int n;
int a;
int b;

// Per-node BFS distance from the start node and BFS parent, written by dfs().
int dist[mxn];
int from[mxn];
// Breadth-first search from `cur` over the global adjacency lists `g`
// (despite its name this is a BFS, not a DFS).
//
// On return:
//   * dist[v] is the edge distance from `cur` to v (1000000 if unreached),
//   * from[v] is v's BFS parent (-1 for `cur` and for unreached nodes),
//   * the edges of the path cur -> to have been appended to the global `edg`
//     in walking order: the first appended edge starts at `cur`, the last
//     one ends at `to`. Nothing is appended when cur == to or when `to` is
//     not reachable from `cur`.
void dfs(int cur, int to) {
    const int unreached = 1000000;
    fill(dist, dist + mxn, unreached);
    // Reset parents too, so a stale link from an earlier call is never followed.
    fill(from, from + mxn, -1);

    dist[cur] = 0;
    queue<int> q;
    q.push(cur);
    while (!q.empty()) {
        const int fr = q.front();
        q.pop();
        for (int x : g[fr]) {
            if (dist[fr] + 1 < dist[x]) {
                dist[x] = dist[fr] + 1;
                from[x] = fr;
                q.push(x);
            }
        }
    }

    // Without this guard the parent walk below would never reach `cur`.
    if (dist[to] >= unreached) return;

    // Collect the path backwards from `to`, then reverse only the edges added
    // by this call so entries already present in `edg` keep their order.
    const size_t first = edg.size();
    while (to != cur) {
        edg.push_back({from[to], to});
        to = from[to];
    }
    reverse(edg.begin() + first, edg.end());
}
pair<int, int> disable = {-1, -1};
int solve(int cur, int par = -1) {
vector<int> times;
for (auto x : g[cur]) {
if (x == par || make_pair(cur, x) == disable || make_pair(x, cur) == disable) continue;
times.push_back(solve(x, cur));
}
int time = 0;
sort(all(times));
reverse(all(times));
int cnt = 0;
for (auto x : times) {
time = max(time, x + cnt);
cnt++;
}
return time + (par != -1);
}
bool check(int edge) {
disable = edg[edge];
int costA = solve(a), costB = solve(b);
return costA <= costB;
}
// Reads n, the sources a and b, and the n - 1 tree edges. Binary-searches the
// edge of the a -> b path to cut and prints the minimum over the candidate
// cuts of max(solve(a), solve(b)).
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> a >> b;
    for (int i = 0; i + 1 < n; i++) {
        int f, t;
        cin >> f >> t;
        g[f].push_back(t);
        g[t].push_back(f);
    }

    dfs(a, b);  // edg = edges of the a -> b path; edg[0] starts at a

    if (edg.empty()) {
        // a == b: the path has no edge to cut (reading edg[0] would be out
        // of bounds), so the single source covers the whole tree.
        disable = {-1, -1};
        cout << solve(a) << endl;
        return 0;
    }

    // The search assumes check() holds for a prefix of path-edge indices.
    // It finds the last index where it holds, or 0 if none does.
    int lo = 0;
    int hi = static_cast<int>(edg.size());
    while (lo + 1 < hi) {
        const int mid = lo + (hi - lo) / 2;
        if (check(mid)) lo = mid;
        else hi = mid;
    }

    // Evaluate the boundary: edge lo, and edge lo + 1 when it exists.
    const int last = static_cast<int>(edg.size()) - 1;
    int ans = 1000000000;
    for (int e = lo; e <= min(lo + 1, last); ++e) {
        disable = edg[e];
        ans = min(ans, max(solve(a), solve(b)));
    }
    cout << ans << endl;
    return 0;
}
/*
Judge submission-page residue pasted after the program (not code):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/