#include <bits/stdc++.h>
// Shorthand macros; none of fi / se / mp is referenced by the code in this file.
#define fi first
#define se second
#define mp make_pair
using namespace std;
// Capacity of the per-vertex arrays; presumably the maximum n plus slack — TODO confirm against the problem limits.
const int N = 3e5 + 5;
// adj[v]: neighbours of vertex v (every input edge is stored in both directions).
vector <int> adj[N];
// Vertices on the tree path from a to b, in order: node[0] == a, ..., node.back() == b.
vector <int> node;
// par[v]: parent of v when the tree is rooted at a (par[a] == -1).
int par[N];
// dp[v]: value computed by solve() for v under the current path colouring.
int dp[N];
// c[v]: colour of a path vertex for the current cut: 1 = a's side, 2 = b's side.
// Vertices that are not on the a-b path are never written and stay 0.
int c[N];
// n: number of vertices; a, b: the two special vertices read from input.
int n,a,b;
// Roots the tree at `u` and fills par[] with parent pointers: par[u] = p and,
// for every other vertex reachable from u without passing through p, par[v] is
// the neighbour through which v was reached.
//
// Iterative (explicit stack) on purpose: a path-shaped tree with ~3e5 vertices
// would need ~3e5 nested recursive calls, which can overflow a default-sized
// thread stack. For a tree the resulting par[] is identical to the recursive
// version's; only the visiting order differs.
void dfs(int u, int p) {
    par[u] = p;
    vector<int> pending(1, u);
    while (!pending.empty()) {
        const int x = pending.back();
        pending.pop_back();
        for (int v : adj[x]) {
            if (v == par[x])
                continue;  // don't walk back to the vertex we came from
            par[v] = x;
            pending.push_back(v);
        }
    }
}
// Fills dp[] for every vertex reachable from `u` without entering `p` and
// without stepping onto a vertex whose colour c[] equals `clr` (the root `u`
// itself is never colour-checked).
//
// For each such vertex x, the dp values of its eligible children are sorted in
// decreasing order into val, and dp[x] = max over i of (val[i] + i + 1). In
// other words, the child at 0-based position i finishes at step i + 1 + its
// own dp when children are served largest-first. A vertex with no eligible
// child gets dp = 0.
//
// Iterative on purpose: vertices are listed in BFS order (each child after its
// parent), then processed in reverse so every child's dp is ready before its
// parent's. The recursive form needed one stack frame (holding a vector) per
// tree level, which can overflow a default stack on a ~3e5-vertex path.
void solve(int u, int p, int clr) {
    // (vertex, parent) pairs in BFS order starting from u.
    vector<pair<int, int> > order;
    order.push_back(make_pair(u, p));
    for (size_t k = 0; k < order.size(); ++k) {
        // Copy out before push_back, which may reallocate `order`.
        const int x = order[k].first;
        const int px = order[k].second;
        for (int v : adj[x])
            if (v != px && c[v] != clr)
                order.push_back(make_pair(v, x));
    }

    vector<int> val;  // reused buffer holding the children's dp values
    for (size_t k = order.size(); k-- > 0;) {
        const int x = order[k].first;
        const int px = order[k].second;
        val.clear();
        for (int v : adj[x])
            if (v != px && c[v] != clr)
                val.push_back(dp[v]);
        sort(val.begin(), val.end(), greater<int>());
        dp[x] = 0;
        for (int i = 0; i < (int) val.size(); i++)
            dp[x] = max(dp[x], val[i] + i + 1);
    }
}
// Reads the tree, collects the a->b path and binary-searches where to cut that
// path between a's side and b's side, then prints the best value found.
//
// For a cut at index k, path vertices node[0..k] get colour 1 (a's side) and
// node[k+1..] get colour 2 (b's side). solve(a, -1, 2) and solve(b, -1, 1) then
// produce dp[a] and dp[b] for the two sides. The search moves lo to cuts where
// dp[a] < dp[b] and hi to cuts where dp[a] >= dp[b], and finally compares the
// two neighbouring cuts lo and hi.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> a >> b;
    for (int remaining = n - 1; remaining > 0; --remaining) {
        int x, y;
        cin >> x >> y;
        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    // Root the tree at a, then climb parent links from b to collect the path.
    dfs(a, -1);
    for (int cur = b; cur != -1; cur = par[cur])
        node.push_back(cur);
    reverse(node.begin(), node.end());
    const int last = (int) node.size() - 1;

    // Colour node[0..endA] with 1 and node[startB..last] with 2.
    auto paint = [&](int endA, int startB) {
        for (int i = 0; i <= endA; ++i)
            c[node[i]] = 1;
        for (int i = startB; i <= last; ++i)
            c[node[i]] = 2;
    };

    int lo = 0;
    int hi = last;
    while (hi - lo > 1) {
        const int mid = lo + (hi - lo) / 2;
        paint(mid, mid + 1);
        solve(a, -1, 2);
        solve(b, -1, 1);
        if (dp[a] >= dp[b])
            hi = mid;
        else
            lo = mid;
    }

    // Evaluate the cut at lo on both sides.
    paint(lo, hi);
    solve(b, -1, 1);
    solve(a, -1, 2);
    int best = max(dp[a], dp[b]);

    // Cut at hi: hi < last means hi was set inside the loop, where
    // dp[a] >= dp[b] held for this same cut, so only dp[a] is needed.
    if (hi < last) {
        c[node[hi]] = 1;
        solve(a, -1, 2);
        best = min(best, dp[a]);
    }

    cout << best;
    return 0;
}
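/*
Grader results:
| # | Result  | Execution time | Memory  | Grader output     |
|---|---------|----------------|---------|-------------------|
| 1 | Correct | 5 ms           | 7372 KB | Output is correct |
| 2 | Correct | 5 ms           | 7372 KB | Output is correct |
| 3 | Correct | 6 ms           | 7368 KB | Output is correct |
*/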
/*
Grader results:
| # | Result  | Execution time | Memory   | Grader output     |
|---|---------|----------------|----------|-------------------|
| 1 | Correct | 540 ms         | 25992 KB | Output is correct |
| 2 | Correct | 581 ms         | 27236 KB | Output is correct |
| 3 | Correct | 545 ms         | 29040 KB | Output is correct |
| 4 | Correct | 527 ms         | 28368 KB | Output is correct |
| 5 | Correct | 518 ms         | 25688 KB | Output is correct |
| 6 | Correct | 447 ms         | 26264 KB | Output is correct |
| 7 | Correct | 457 ms         | 29548 KB | Output is correct |
*/