#include <bits/stdc++.h>
#define pb push_back
#define F first
#define S second
// Prints `x= <value>, ` for quick debugging. The value is parenthesized so that
// operator expressions such as debug(a << b) or debug(a & b) print their result.
#define debug(x) cout << #x << "= " << (x) << ", "
#define ll long long
#define fast ios::sync_with_stdio(false), cin.tie(0), cout.tie(0)
// Container size as a signed int. Argument and result are fully parenthesized
// so non-primary arguments like SZ(*p) expand correctly.
#define SZ(x) ((int)(x).size())
#define wall cout << endl;
using namespace std;
// Random engine seeded from the steady clock; nothing in this file's visible code draws from it.
std::mt19937 rng(std::chrono::steady_clock::now().time_since_epoch().count());

// Capacity of every per-node array below; node labels are used directly as indices.
constexpr int maxn = 1000000 + 10;

int n;                       // node count, read in main
int dp[maxn];                // per-node value computed by dfs()
int m;                       // node where main's walk toward t starts
int t;                       // root passed to dfs(); main's walk stops here
int par[maxn];               // parent of each node once dfs(t) has run (root keeps 0)
std::vector<int> adj[maxn];  // undirected adjacency lists
bool vis[maxn];              // set by main for t and for every node on the walk from m
struct amirhoseinfar1385team{
    // Two independent counters laid out over the same implicit binary tree on [0, n):
    // node 1 is the root, node k has children 2k (range [nl, mid)) and 2k+1
    // (range [mid, nr)), with mid = (nl + nr) / 2.
    //   cnt[k] - how many Add() indices fell inside node k's range.
    //   col[k] - how many Add_col() indices fell inside node k's range.
    int cnt[(maxn << 2)] , col[(maxn << 2)];

    // Record one occurrence of `ind` in cnt (every node on its root-to-leaf path is bumped).
    // An index below the current mid always goes left, so negative values land on leaf 0.
    void Add(int ind , int node = 1 , int nl = 0 , int nr = n)
    {
        for (;;)
        {
            ++cnt[node];
            if (nr - nl == 1) return;
            const int mid = (nl + nr) >> 1;
            const int left = node << 1;
            if (ind < mid) { node = left; nr = mid; }
            else { node = left | 1; nl = mid; }
        }
    }

    // Return the leaf index holding the num-th largest value recorded by Add()
    // (counting duplicates). The caller is expected to keep 1 <= num <= cnt[1].
    int Find_mx(int num , int node = 1 , int nl = 0 , int nr = n)
    {
        while (nr - nl != 1)
        {
            const int mid = (nl + nr) >> 1;
            const int left = node << 1;
            const int right = left | 1;
            if (num <= cnt[right]) { node = right; nl = mid; }
            else { num -= cnt[right]; node = left; nr = mid; }
        }
        return nl;
    }

    // Return how many Add_col() indices are >= ind: every right sibling skipped
    // while descending toward `ind` lies entirely above it, plus the leaf itself.
    int Find_col(int ind , int node = 1 , int nl = 0 , int nr = n)
    {
        int above = 0;
        while (nr - nl != 1)
        {
            const int mid = (nl + nr) >> 1;
            const int left = node << 1;
            if (ind < mid) { above += col[left | 1]; node = left; nr = mid; }
            else { node = left | 1; nl = mid; }
        }
        return above + col[node];
    }

    // Record one occurrence of `ind` in col; same descent rule as Add().
    void Add_col(int ind , int node = 1 , int nl = 0 , int nr = n)
    {
        for (;;)
        {
            ++col[node];
            if (nr - nl == 1) return;
            const int mid = (nl + nr) >> 1;
            const int left = node << 1;
            if (ind < mid) { node = left; nr = mid; }
            else { node = left | 1; nl = mid; }
        }
    }
} seg;
// Recursive post-order traversal from v over adj[], skipping the edge back to par[v].
// Sets par[] for every child before descending, then adds to dp[v]:
//   (number of children) + (second-largest dp among the children, or 0 when
//   there are fewer than two children).
void dfs(int v)
{
    int top1 = 0;   // largest child dp seen so far (floored at 0)
    int top2 = 0;   // second largest child dp seen so far (floored at 0)
    int children = 0;

    for (int child : adj[v])
    {
        if (child == par[v]) continue;
        par[child] = v;
        dfs(child);
        ++children;

        const int val = dp[child];
        if (val > top1)
        {
            top2 = top1;
            top1 = val;
        }
        else if (val > top2)
        {
            top2 = val;
        }
    }

    dp[v] += children + top2;
}
// Entry point. Reads n, t, m and the n - 1 undirected edges, roots the tree at t via
// dfs(t), then walks from m up toward t. Off-path branch statistics are accumulated
// in `seg` along the way, and the largest candidate value is printed.
int32_t main()
{
    fast;
    cin >> n >> t >> m;
    for(int i = 0 ; i < n - 1 ; i++)
    {
        int v , u; cin >> v >> u;
        adj[v].pb(u);
        adj[u].pb(v);
    }
    dfs(t);

    // Collect the path m -> ... -> (child of t) and mark it (plus t) in vis[].
    // sum_deg counts the off-path neighbours of all path nodes:
    //   - deg - 1 for m, whose only on-path edge goes to its parent;
    //   - deg - 2 for every later node (edge to its parent and edge back down the path).
    int now = m , sum_deg = 0 , prev = 1;
    vis[t] = true;
    vector <int> path;
    while(now != t)
    {
        vis[now] = true;
        path.pb(now);
        sum_deg += (SZ(adj[now]) - prev);
        prev = 2;
        now = par[now];
    }

    int ans = 0;
    for(int i = 0 ; i < SZ(path) ; i++)
    {
        // Insert the dp of every off-path neighbour of this path node into seg's cnt,
        // and remember the largest such dp (-1 if the node has none).
        int deg_now = 0 , mxc = -1 , v = path[i];
        for(auto u : adj[v]) if(!vis[u])
        {
            deg_now++;
            seg.Add(dp[u]);
            mxc = max(mxc , dp[u]);
        }

        // mx: the (i + 2)-th largest dp inserted so far, or 0 when fewer values exist.
        int mx = 0;
        if(i + 2 <= seg.cnt[1]) mx = seg.Find_mx(i + 2);

        // tmp: number of earlier path nodes whose largest branch dp is >= mx.
        int tmp = seg.Find_col(mx);
        ans = max(ans , sum_deg + tmp + mx);
        sum_deg -= deg_now;

        // Only nodes that actually have an off-path branch contribute to col.
        // Previously Add_col(-1) descended to leaf 0 and created a phantom entry
        // that a later Find_col(0) counted. Example: bare path 1-2-3 with t = 1, m = 3
        // printed 1 before this guard and prints 0 with it.
        if(mxc >= 0) seg.Add_col(mxc);
    }
    cout << ans << '\n';
    return 0;
}
/*
Grader results:
# | Result    | Execution time | Memory   | Grader output
1 | Incorrect | 11 ms          | 23816 KB | Output isn't correct
2 | Halted    | 0 ms           | 0 KB     | -
*/
/*
Grader results:
# | Result  | Execution time | Memory   | Grader output
1 | Correct | 317 ms         | 76456 KB | Output is correct
2 | Correct | 294 ms         | 71132 KB | Output is correct
3 | Correct | 777 ms         | 77292 KB | Output is correct
4 | Correct | 337 ms         | 50392 KB | Output is correct
5 | Correct | 756 ms         | 77352 KB | Output is correct
6 | Correct | 775 ms         | 77212 KB | Output is correct
*/
/*
Grader results:
# | Result    | Execution time | Memory   | Grader output
1 | Incorrect | 11 ms          | 23816 KB | Output isn't correct
2 | Halted    | 0 ms           | 0 KB     | -
*/
/*
Grader results:
# | Result    | Execution time | Memory   | Grader output
1 | Incorrect | 11 ms          | 23816 KB | Output isn't correct
2 | Halted    | 0 ms           | 0 KB     | -
*/