#include <bits/stdc++.h>
using namespace std;
#define ll long long  // shorthand type alias; not used in the visible code

// Capacity of the per-node arrays below; node ids index into them directly.
constexpr int N = 300010;
// Not used in the visible code.
constexpr int M = 10000;
// Not used in the visible code.
constexpr long long inf = 100000000000000000LL;

int n;                  // number of nodes, read in main()
std::vector<int> g[N];  // adjacency lists; main() reads n - 1 undirected edges (presumably a tree)
int x;                  // first starting node
int y;                  // second starting node; also the node dfs() searches for
int m;                  // set by dfs(): index of y in path[]
int path[N];            // path from x to y: path[1] == x, ..., path[m] == y
/*
 * Finds the tree path from u to the global node y without stepping back into p.
 *
 * On success it writes the path into path[d], path[d + 1], ..., path[m]
 * (so path[d] == u and path[m] == y), sets the global m and returns true.
 * If y is not reachable, it returns false and leaves path[] and m untouched.
 * main() calls dfs(x, x, 1), so nothing is excluded at the root.
 *
 * Only the immediate parent is excluded (there is no visited set), so g is
 * assumed to be a tree.
 *
 * An explicit stack replaces recursion. On a line-shaped tree the path has up
 * to n - 1 edges, and that many nested call frames can overflow a
 * default-sized stack. Neighbours are tried in the same order as the
 * recursive version.
 */
bool dfs(int u, int p, int d) {
    struct Frame {
        int node;          // vertex owned by this frame
        int parent;        // vertex we arrived from; never re-entered
        std::size_t next;  // index into g[node] of the next neighbour to try
    };
    std::vector<Frame> frames;
    frames.push_back({u, p, 0});
    while (!frames.empty()) {
        const int cur = frames.back().node;
        if (cur == y) {
            // The frame stack is exactly the route u -> ... -> y.
            for (std::size_t i = 0; i < frames.size(); ++i)
                path[d + static_cast<int>(i)] = frames[i].node;
            m = d + static_cast<int>(frames.size()) - 1;
            return true;
        }
        Frame &top = frames.back();
        if (top.next == g[cur].size()) {
            frames.pop_back();  // every neighbour tried: backtrack
            continue;
        }
        const int v = g[cur][top.next++];
        // `top` is read before push_back, which may reallocate `frames`.
        if (v != top.parent)
            frames.push_back({v, cur, 0});
    }
    return false;
}
// Per-node DP values written by DFS(). DFS only ever raises them (via max),
// so they must be zeroed before each fresh computation.
int dp[N];
/*
 * Fills dp[] for every node reachable from u (whose parent is p) without
 * passing through node b. This effectively cuts the tree just before b.
 *
 * For each reached node w, let c_1..c_k be its children (neighbours other than
 * its parent and b), ordered so that dp[c_1] <= ... <= dp[c_k]. Then
 *     dp[w] = max(dp[w], max over i of dp[c_i] + (k - i + 1)).
 * The child with the largest dp gets slot 1 and the smallest gets slot k.
 * Presumably dp[w] is the time w needs to reach its whole part of the tree
 * when each informed node passes the message to one neighbour per time unit.
 *
 * The work is done iteratively in two passes to avoid recursion depth
 * proportional to the tree height (up to n - 1). Pass 1 lists nodes in
 * preorder; pass 2 processes them in reverse, so children come before parents.
 */
void DFS(int u, int p, int b) {
    // Pass 1: (node, parent) pairs; a node is recorded before any of its children.
    std::vector<std::pair<int, int>> order;
    std::vector<std::pair<int, int>> pending;
    pending.emplace_back(u, p);
    while (!pending.empty()) {
        const std::pair<int, int> cur = pending.back();
        pending.pop_back();
        order.push_back(cur);
        for (int v : g[cur.first])
            if (v != cur.second && v != b)
                pending.emplace_back(v, cur.first);
    }

    // Pass 2: every child's dp is final before its parent is processed.
    std::vector<int> s;  // children's dp values, buffer reused across nodes
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int w = it->first;
        const int parent = it->second;
        s.clear();
        for (int v : g[w])
            if (v != parent && v != b)
                s.push_back(dp[v]);
        std::sort(s.begin(), s.end());
        const int k = static_cast<int>(s.size());
        for (int i = 0; i < k; ++i)
            dp[w] = std::max(dp[w], s[i] + k - i);
    }
}
/*
 * Evaluates cutting the path edge (path[i], path[i + 1]).
 *
 * x keeps the component on its side of the edge and y keeps the other one.
 * Returns the larger of the two DP values dp[x] and dp[y]. All of dp[] is
 * cleared first, and dp[x] / dp[y] are left in place for the caller to inspect.
 */
int solve(int i) {
    std::fill(std::begin(dp), std::end(dp), 0);  // DFS only raises dp values
    const int first_on_y_side = path[i + 1];
    const int last_on_x_side = path[i];
    DFS(x, x, first_on_y_side);  // x's half: never enters path[i + 1]
    DFS(y, y, last_on_x_side);   // y's half: never enters path[i]
    const int time_x = dp[x];
    const int time_y = dp[y];
    return time_x < time_y ? time_y : time_x;
}
/*
 * Reads n, the two starting nodes x and y, and n - 1 undirected edges. Then
 * picks the edge on the x-y path to cut that minimises
 * max(dp[x], dp[y]) (see solve()), and prints that minimum.
 */
int main() {
    // Only iostreams are used (no C stdio), so unsynchronised, untied
    // streams are safe. They noticeably speed up reading ~2n integers.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin >> n >> x >> y;
    for (int i = 1; i < n; ++i) {
        int u, v;
        std::cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }

    // path[1..m] becomes the node sequence x ... y.
    dfs(x, x, 1);

    // Candidate cuts are edges (path[i], path[i + 1]) for i in [1, m - 1].
    // The search relies on dp[x] being non-decreasing in i (x's side only
    // grows) and dp[y] being non-increasing. It finds the last i with
    // dp[x] <= dp[y], treating i = 1 as the fallback. The optimum is then at
    // that i or at i + 1.
    int l = 1;
    int r = m - 1;
    while (l < r) {
        const int mid = (l + r) / 2;
        solve(mid + 1);  // only dp[x] / dp[y] are inspected
        if (dp[x] <= dp[y]) l = mid + 1;
        else r = mid;
    }
    int ans = solve(l);
    if (l < m - 1) ans = std::min(ans, solve(l + 1));
    std::cout << ans << "\n";
    return 0;
}
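/*
 * Online-judge results table pasted after the program (not part of the solution):
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */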