#include<bits/stdc++.h>
using namespace std;
// Upper bound on vertex ids; the arrays below are indexed by vertex
// (main reads vertices 1..n, so this assumes n < N).
constexpr int N = 300300;
// adj2: undirected adjacency lists built from the input edges.
// adj:  for each vertex, its children inside the subtrees hanging off the
//       X-Y path (path vertices are never stored as children); built by dfs1.
vector<int> adj[N], adj2[N];
// Vertices of the X-Y path (filled by find_path, then put in X..Y order by main).
vector<int> nodes;
// vis: visited marks used by find_path; in: true for vertices on the X-Y path.
bool vis[N], in[N];
// The two source vertices read from the input.
int X, Y;
// dp[v]: minimum time to fill in v's subtree once v is active.
int dp[N];
bool find_path(int v) {
vis[v] = 1;
if (v == Y) {
in[v] = 1;
nodes.push_back(v);
return true;
}
for (auto& x : adj2[v]) {
if (vis[x]) continue;
if (find_path(x)) {
in[v] = 1;
nodes.push_back(v);
return true;
}
}
return false;
}
// Builds rooted child lists in adj[] for the subtree reachable from v in
// adj2, never stepping onto a vertex with in[] set (the X-Y path) or back to
// the parent p. Children are appended in adj2 order, exactly as the original
// recursive version did. An explicit stack avoids deep recursion on long chains.
void dfs1(int v, int p = -1) {
    vector<pair<int, int>> todo;  // (vertex, its parent)
    todo.push_back({v, p});
    while (!todo.empty()) {
        const int u = todo.back().first;
        const int parent = todo.back().second;
        todo.pop_back();
        for (int x : adj2[u]) {
            if (in[x] || x == parent) continue;
            adj[u].push_back(x);
            todo.push_back({x, u});
        }
    }
}
void dfs(int v) {
for (auto& x : adj[v]) dfs(x);
dp[v] = 0;
sort(adj[v].begin(), adj[v].end(),
[&](int a, int b) {return dp[a] > dp[b];});
int c = 0;
for (auto& x : adj[v]) {
dp[v] = max(dp[v], dp[x] + (++c));
}
}
int main() {
cin.tie(nullptr)->sync_with_stdio(false);
int n;
cin >> n >> X >> Y;
for (int i = 1;i < n;i++) {
int u, v;
cin >> u >> v;
adj2[u].push_back(v);
adj2[v].push_back(u);
}
// find node in path between (x, y)
find_path(X);
reverse(nodes.begin(), nodes.end());
// calculate dp
for (int i = 1;i <= n;i++) {
if (in[i]) dfs1(i), dfs(i);
}
for (int i = 1;i <= n;i++) {
adj2[i].clear();
//cout << dp[i] << ' ';
}
//cout << '\n';
int l = max(dp[X], dp[Y]), r = n - 1;
int sz = nodes.size();
if (sz == 2) {
cout << l << '\n';
return 0;
}
while (l < r) {
int mid = (l + r) / 2;
int l1, r1;
// find maximum possible index x can reach
l1 = 0, r1 = sz - 2;
while (l1 < r1) {
int mid1 = (l1 + r1 + 1) / 2;
{
int v = nodes[mid1];
dp[v] = 0;
sort(adj[v].begin(), adj[v].end(),
[&](int a, int b) {return dp[a] > dp[b];});
int c = 0;
for (auto& x : adj[v]) {
dp[v] = max(dp[v], dp[x] + (++c));
}
}
for (int j = mid1 - 1;j >= 0;j--) {
int u = nodes[j], v = nodes[j + 1];
dp[u] = 0;
adj2[u] = adj[u];
adj2[u].push_back(v);
sort(adj2[u].begin(), adj2[u].end(),
[&](int a, int b) {return dp[a] > dp[b];});
int c = 0;
for (auto& x : adj2[u]) {
dp[u] = max(dp[u], dp[x] + (++c));
}
adj2[u].clear();
}
if (dp[X] <= mid) l1 = mid1;
else r1 = mid1 - 1;
}
int ansx = l1;
// find minimum possible index y can reach
l1 = 1, r1 = sz - 1;
while (l1 < r1) {
int mid1 = (l1 + r1) / 2;
{
int v = nodes[mid1];
dp[v] = 0;
sort(adj[v].begin(), adj[v].end(),
[&](int a, int b) {return dp[a] > dp[b];});
int c = 0;
for (auto& x : adj[v]) {
dp[v] = max(dp[v], dp[x] + (++c));
}
}
for (int j = mid1 + 1;j < sz;j++) {
int u = nodes[j], v = nodes[j - 1];
dp[u] = 0;
adj2[u] = adj[u];
adj2[u].push_back(v);
sort(adj2[u].begin(), adj2[u].end(),
[&](int a, int b) {return dp[a] > dp[b];});
int c = 0;
for (auto& x : adj2[u]) {
dp[u] = max(dp[u], dp[x] + (++c));
}
adj2[u].clear();
}
if (dp[Y] <= mid) r1 = mid1;
else l1 = mid1 + 1;
}
int ansy = l1;
//cout << mid << ' ' << ansx << ' ' << ansy << '\n';
// shift (l, r)
if (ansx + 1 >= ansy) r = mid;
else l = mid + 1;
}
cout << l << '\n';
return 0;
}
/*
 * Judge results (first table):
 *   # | Result  | Execution time | Memory   | Grader output
 *   1 | Correct | 7 ms           | 14424 KB | Output is correct
 *   2 | Correct | 6 ms           | 14428 KB | Output is correct
 *   3 | Correct | 7 ms           | 14428 KB | Output is correct
 */
/*
 * Judge results (second table):
 *   # | Result  | Execution time | Memory   | Grader output
 *   1 | Correct | 307 ms         | 30828 KB | Output is correct
 *   2 | Correct | 424 ms         | 35412 KB | Output is correct
 *   3 | Correct | 713 ms         | 35724 KB | Output is correct
 *   4 | Correct | 495 ms         | 36152 KB | Output is correct
 *   5 | Correct | 296 ms         | 35152 KB | Output is correct
 *   6 | Correct | 521 ms         | 35412 KB | Output is correct
 *   7 | Correct | 826 ms         | 36756 KB | Output is correct
 */