#include <bits/stdc++.h>
using namespace std;
// Contest shorthand macros. Several of them (ull, sa, mp, vf, vs, ar, all)
// are not referenced anywhere else in this file.
#define ll long long
#define ull unsigned long long
#define pb(e) push_back(e)
// sv: sort an entire container; sa: sort the first n elements of a C array.
#define sv(a) sort(a.begin(),a.end())
#define sa(a,n) sort(a,a+n)
#define mp(a,b) make_pair(a,b)
#define vf first
#define vs second
#define ar array
#define all(x) x.begin(),x.end()
// Common contest constants; none of these three is used by the solution below.
const int inf = 0x3f3f3f3f;
const int mod = 1000000007;
const double PI=3.14159265358979323846264338327950288419716939937510582097494459230;
// Returns true when a % b is nonzero, i.e. b does not divide a evenly.
// b must be nonzero (division by zero is undefined behaviour).
bool remender(ll a , ll b){
    const ll leftover = a % b;
    return leftover != 0;
}
// Capacity of every per-node array; nodes are indexed directly by the labels
// read from input.
const int N = 300006;
// Undirected adjacency lists of the input tree.
vector<int> adj[N];
// Visited / claimed flags, cleared with memset before each traversal phase
// in solve().
bool vis[N];
// par[x]: parent of x with the tree rooted at a (solve() sets par[a] = -1).
int par[N];
// mx[x] (filled by dfss): max over k of (k + k-th largest value in v[x]),
// i.e. the finish time of x's hanging subtree when x handles one child per
// step, largest value first.
int mx[N];
// v[x] (filled by dfss): mx values of x's hanging children, sorted ascending.
vector<int> v[N];
// The two start nodes read from input.
int a , b;
// adj1[x]: the node after x on the path from a towards b (set in solve()).
int adj1[N];
// Post-order pass over the subtree hanging off `node`. It skips par[node] and
// any node already marked in vis[].
// On return:
//   v[node]  holds the children's mx values, sorted ascending;
//   mx[node] = max over k >= 1 of (k + k-th largest child value), which is
//              the finish time when one child is served per step, largest
//              first (0 if there are no children).
void dfss(int node){
    vis[node] = 1;
    mx[node] = 0;
    auto& kids = v[node];
    for (int nxt : adj[node]) {
        if (vis[nxt] || nxt == par[node]) continue;
        dfss(nxt);
        kids.push_back(mx[nxt]);
    }
    sv(kids);
    // Walk from the largest value down; the k-th served child finishes at k + value.
    int served = 0;
    for (auto it = kids.rbegin(); it != kids.rend(); ++it) {
        ++served;
        mx[node] = max(mx[node], served + *it);
    }
}
// Depth-first traversal from `node`. It marks every reachable node in vis[]
// and records, for each newly discovered node, the node it was reached from
// in par[]. The starting node's own par entry is not written here.
void dfs(int node){
    vis[node] = 1;
    for (int child : adj[node]) {
        if (vis[child]) continue;
        par[child] = node;
        dfs(child);
    }
}
// Greedy walk from b towards a along par[] under the time budget `total`.
// `done` is the step at which `node` is reached.
// - A node is claimed (vis = 1) only if done + mx[node] <= total, i.e. it can
//   still finish its hanging subtrees.
// - It serves its hanging children largest value first. It hands the walk to
//   par[node] at the earliest step that still lets every remaining child
//   finish within `total`.
// - The walk ends at a, or when the budget is used up before a hand-off.
// Written as a loop (every step is a tail call) so the stack does not grow
// with the a-b path length. Uses std::vector instead of a variable-length
// array, which is a GCC extension and not valid ISO C++.
void dfs1(int node , int done , int total){
    while (node != a && done + mx[node] <= total) {
        vis[node] = 1;
        const int up = par[node];
        const std::vector<int>& kids = v[node];  // ascending, as left by dfss
        const int k = static_cast<int>(kids.size());
        if (k == 0) {
            // Nothing hangs here: pass the walk on immediately (if not claimed).
            if (vis[up]) return;
            node = up;
            done += 1;
            continue;
        }
        // best[i] = max over j <= i of (kids[j] + k - j). This is the latest
        // finish among children 0..i, relative to the arrival step, when
        // children are served largest first (child j is served (k - j)-th).
        std::vector<int> best(k);
        best[0] = kids[0] + k;
        for (int i = 1; i < k; i++) {
            best[i] = std::max(best[i - 1], kids[i] + k - i);
        }
        const int start = done;
        bool handedOff = false;
        for (int i = k - 1; i >= 0 && done < total; i--) {
            // Handing off now delays children 0..i by one step each.
            if (best[i] + 1 + start <= total && vis[up] == 0) {
                handedOff = true;
                break;
            }
            done++;  // serve child i first instead
        }
        if (!handedOff && done >= total) return;
        node = up;
        done += 1;
    }
}
// Greedy walk from a towards b along adj1[] under the time budget `total`.
// `done` is the step at which `node` is reached.
// - A node is claimed (vis = 1) only if done + mx[node] <= total.
// - The walk stops before stepping onto an already-claimed node (e.g. one
//   claimed by dfs1 for the same budget).
// - Otherwise it serves hanging children largest value first and hands off
//   to adj1[node] at the earliest step that keeps every remaining child
//   within `total`.
// - The walk also ends at b, or when the budget is used up before a hand-off.
// Written as a loop (every step is a tail call) and with std::vector instead
// of a variable-length array (a GCC extension, not valid ISO C++).
void dfs2(int node , int done , int total){
    while (node != b && done + mx[node] <= total) {
        vis[node] = 1;
        const int nxt = adj1[node];
        if (vis[nxt]) return;
        const std::vector<int>& kids = v[node];  // ascending, as left by dfss
        const int k = static_cast<int>(kids.size());
        if (k > 0) {
            // best[i] = max over j <= i of (kids[j] + k - j): latest finish
            // among children 0..i when served largest first.
            std::vector<int> best(k);
            best[0] = kids[0] + k;
            for (int i = 1; i < k; i++) {
                best[i] = std::max(best[i - 1], kids[i] + k - i);
            }
            const int start = done;
            bool handedOff = false;
            for (int i = k - 1; i >= 0 && done < total; i--) {
                // Handing off now delays children 0..i by one step each.
                if (best[i] + 1 + start <= total && vis[nxt] == 0) {
                    handedOff = true;
                    break;
                }
                done++;  // serve child i first instead
            }
            if (!handedOff && done >= total) return;
        }
        node = nxt;
        done += 1;
    }
}
// Reads n, the start nodes a and b, and the n - 1 tree edges. It then
// binary-searches the smallest budget `mid` for which the greedy walks
// dfs1 (from b) and dfs2 (from a) together claim every interior node of the
// a-b path. Prints that budget, or -1 if none in [max(mx[a], mx[b]), n]
// passes.
// adj, v and mx are never cleared, so as written this is only safe to run
// once per process.
void solve(){
int n;
cin >> n;
cin >> a >> b;
// Read the n - 1 undirected edges.
for(int i = 0; i < n - 1; i++){
int x , y;
cin >> x >> y;
adj[x].pb(y);
adj[y].pb(x);
}
// Root the tree at a so that par[] leads from b up to a.
memset(vis,0,sizeof vis);
dfs(a);
par[a] = -1;
// path: interior nodes of the a-b path, ordered from b's neighbour towards
// a's neighbour (a and b themselves are not included).
vector<int> path;
int cur = b;
while(cur != a){
cur = par[cur];
if(cur == a)break;
path.pb(cur);
}
// adj1[x]: next node after x when walking from a towards b.
// An empty path means par[b] == a (or a == b).
if(path.size() == 0)adj1[a] = b;
else {
adj1[a] = path.back();
adj1[path[0]] = b;
for(int i = 0; i < (int)path.size() - 1; i++){
adj1[path[i + 1]] = path[i];
}
}
// Fill mx[]/v[] for b, each path node, then a, each over the subtree hanging
// off it. dfss skips par[x] and already-visited nodes, so with this order no
// path node ends up inside another's subtree.
memset(vis,0,sizeof vis);
dfss(b);
for(int i : path)dfss(i);
dfss(a);
// Lower bound: the check below only inspects path nodes, so budgets smaller
// than what a and b need for their own hanging subtrees are excluded here.
int l = max(mx[a] , mx[b]) , r = n;
//cout << l << '\n';
int ans = - 1;
while(l <= r){
int mid = (l + r)/2;
// Fresh claim marks for this budget; dfs2 will not step onto nodes that
// dfs1 has already claimed.
memset(vis,0,sizeof vis);
dfs1(b , 0 , mid);
dfs2(a , 0 , mid);
// Feasible iff every interior path node was claimed by one of the walks.
int possible = 1;
for(int i : path){
if(vis[i] == 0)possible = 0;
}
if(possible == 1){
ans = mid;
r = mid - 1;
}
else l = mid + 1;
}
cout << ans << '\n';
}
int main(){
    // Only C++ streams are used, so detach them from C stdio and untie cin
    // from cout for faster bulk I/O.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    // File redirection and the multi-test loop are left disabled:
    //freopen("dec.in", "r", stdin);
    //freopen("dec.out", "w", stdout);
    //int t;cin >> t;while(t--)
    solve();
    return 0;
}
/*
 * Grader results (table 1):
 * #  | Verdict | Execution time | Memory   | Grader output
 * 1  | Correct | 9 ms           | 14668 KB | Output is correct
 * 2  | Correct | 9 ms           | 14676 KB | Output is correct
 * 3  | Correct | 8 ms           | 14680 KB | Output is correct
 */
/*
 * Grader results (table 2):
 * #  | Verdict | Execution time | Memory   | Grader output
 * 1  | Correct | 117 ms         | 32384 KB | Output is correct
 * 2  | Correct | 152 ms         | 37548 KB | Output is correct
 * 3  | Correct | 132 ms         | 38728 KB | Output is correct
 * 4  | Correct | 137 ms         | 38500 KB | Output is correct
 * 5  | Correct | 130 ms         | 36936 KB | Output is correct
 * 6  | Correct | 124 ms         | 37516 KB | Output is correct
 * 7  | Correct | 129 ms         | 39788 KB | Output is correct
 */