Submission #563755

# | Time | Username | Problem | Language | Result | Execution time | Memory
563755 | | ngpin04 | Mousetrap (CEOI17_mousetrap) | C++14 | 100 / 100 | 815 ms | 145056 KiB
#include <bits/stdc++.h>
// Short-hand macros from the author's contest template.
// fi / se: pair member access; mp: std::make_pair.
#define fi first
#define se second
#define mp make_pair
// Task name for file I/O; empty, and not referenced by the code shown.
#define TASK ""
// bit(x): a long long with only bit x set.
#define bit(x) (1LL << (x))
// getbit(x, i): bit i of x (0 or 1).
#define getbit(x, i) (((x) >> (i)) & 1)
// ALL(x): expands to the begin/end iterator pair of container x.
#define ALL(x) (x).begin(), (x).end() 
using namespace std;
/// Lowers `a` to `b` when `b` is strictly smaller (tested as `a > b`).
/// @return true iff `a` was overwritten.
template <typename T1, typename T2>
bool mini(T1 &a, T2 b) {
	if (!(a > b))
		return false;
	a = b;
	return true;
}
/// Raises `a` to `b` when `b` is strictly larger (tested as `a < b`).
/// @return true iff `a` was overwritten.
template <typename T1, typename T2>
bool maxi(T1 &a, T2 b) {
	const bool improves = a < b;
	if (improves)
		a = b;
	return improves;
}
// 64-bit Mersenne Twister seeded from the current steady_clock reading, so the
// sequence differs between runs.
std::mt19937_64 rd(std::chrono::steady_clock::now().time_since_epoch().count());

/// Returns a pseudo-random integer in the closed range [l, r], drawn from `rd`
/// by modulo reduction.
/// Precondition: l <= r.
int rand(int l, int r) {
	assert(l <= r && "rand: empty range");
	// The width is computed in 64-bit arithmetic. In plain int, `r - l + 1`
	// overflows (UB) for ranges wider than INT_MAX, and wraps to 0 (division
	// by zero) for [INT_MIN, INT_MAX].
	const unsigned long long width =
		static_cast<unsigned long long>(static_cast<long long>(r) - l + 1);
	return static_cast<int>(l + static_cast<long long>(rd() % width));
}
// ---- constants ----
// Capacity of every per-node array below. Presumably n <= 1e6; TODO confirm
// against the problem limits.
constexpr int N = 1e6 + 5;
// Template constants; none of these is referenced by the code shown.
constexpr int oo = 1e9;
constexpr long long ooo = 1e18;
constexpr int mod = 1e9 + 7; // alternative modulus: 998244353
const long double pi = std::acos(-1);

// ---- tree and path ----
std::vector<int> adj[N]; // undirected adjacency lists, filled in main()
std::vector<int> a;      // path s -> ... towards t (t itself removed), built in main()

// ---- per-node data ----
int deg[N];  // child / off-path-branch counts, written by solve() and main()
int par[N];  // parent when the tree is rooted at s, written by dfs()
int dp[N];   // per-node value computed by solve()
int n, t, s; // read in this order by main(); t is the path end, s the root/start

bool flag[N]; // true for every node on the s..t path, t included

/// Roots the tree at `u` and stores in par[] the parent of every node
/// reachable from `u` without going back through `p`. par[u] is set to `p`
/// (-1 by default), which main() uses as the end-of-path sentinel.
///
/// Iterative on purpose: the recursive form nests one call per tree level.
/// A path-shaped tree with up to ~1e6 nodes (see N) can overflow a
/// default-sized call stack (e.g. 8 MiB).
void dfs(int u, int p = -1) {
	// Pending (node, parent) pairs. Static so the buffer is reused between calls.
	static std::vector<std::pair<int, int>> pending;
	pending.clear();
	pending.emplace_back(u, p);
	while (!pending.empty()) {
		const int x = pending.back().first;
		const int px = pending.back().second;
		pending.pop_back();
		par[x] = px;
		for (int v : adj[x])
			if (v != px)
				pending.emplace_back(v, x);
	}
}

/// For every node x in the subtree hanging from `u` (entered from parent `p`),
/// computes:
///   deg[x] = number of children of x;
///   dp[x]  = deg[x] + second-largest dp[] among x's children
///            (0 when x has fewer than two children).
///
/// Iterative: nodes are listed in BFS order, then processed in reverse so that
/// every child is finished before its parent. This avoids recursion depth
/// proportional to the branch height, which can reach ~1e6 (see N).
void solve(int u, int p = -1) {
	// (node, parent) pairs in BFS order. Static so the buffer is reused across
	// the many calls made from main().
	static std::vector<std::pair<int, int>> order;
	order.clear();
	order.emplace_back(u, p);
	for (std::size_t i = 0; i < order.size(); ++i) {
		const int x = order[i].first;
		const int px = order[i].second;
		for (int v : adj[x])
			if (v != px)
				order.emplace_back(v, x);
	}

	for (std::size_t i = order.size(); i-- > 0;) {
		const int x = order[i].first;
		const int px = order[i].second;
		int best = 0;     // largest child dp seen so far (floor 0)
		int second = 0;   // second-largest child dp seen so far (floor 0)
		int children = 0;
		for (int v : adj[x]) {
			if (v == px)
				continue;
			++children;
			if (dp[v] > second) {
				second = dp[v];
				if (best < second)
					std::swap(best, second);
			}
		}
		deg[x] = children;
		dp[x] = second + children;
	}
}

/// Feasibility test used by the binary search in main(): returns whether
/// `lim` operations are accepted.
///
/// Reads these globals:
///   - the path `a` (from s towards t, t excluded);
///   - deg[] of the path nodes (their off-path neighbour counts);
///   - flag[] (on-path marker);
///   - dp[] of each off-path neighbour.
///
/// `remaining` starts as the total number of off-path branches. If that alone
/// exceeds `lim`, the answer is no.
///
/// Walking the path from s, consider each off-path branch v at a[i]:
///   - if dp[v] + remaining > lim, it is counted in `forced`;
///   - otherwise it is subtracted from `remaining` once a[i] has been scanned.
/// The test fails as soon as `forced` exceeds i + 1.
///
/// NOTE(review): i + 1 appears to be the number of operations available by
/// the time the mouse reaches a[i]; confirm against the problem analysis.
bool check(int lim) {
	int remaining = 0;
	for (std::size_t k = 0; k < a.size(); ++k)
		remaining += deg[a[k]];
	if (remaining > lim)
		return false;

	int forced = 0;
	const int len = static_cast<int>(a.size());
	for (int i = 0; i < len; ++i) {
		// Branches at a[i] that do not need blocking. Subtracted only after the
		// scan, so every branch at a[i] is judged against the same `remaining`.
		int relaxed = 0;
		for (int v : adj[a[i]]) {
			if (flag[v])
				continue;
			if (dp[v] + remaining > lim)
				++forced;
			else
				++relaxed;
		}
		remaining -= relaxed;
		if (forced > i + 1)
			return false;
	}
	return true;
}

/// Entry point. Steps:
///   1. Read n, t and s, then n - 1 undirected edges.
///   2. Build the path from s to t and mark it in flag[].
///   3. Fill deg[]/dp[] for every off-path branch of the path.
///   4. Binary-search the smallest lim in [0, n] accepted by check(), assuming
///      check() is monotone in lim, and print it with no trailing newline.
int main() {
	std::ios_base::sync_with_stdio(false);
	std::cin.tie(nullptr);
#ifdef ONLINE_JUDGE
	// File I/O is used only when ONLINE_JUDGE is defined.
	// NOTE(review): this is the reverse of the common convention. On a judge
	// that defines ONLINE_JUDGE but feeds stdin, input would come from a
	// missing file. Confirm this is intended.
	freopen("mousetrap.inp", "r", stdin);
	freopen("mousetrap.out", "w", stdout);
#endif
	std::cin >> n >> t >> s;
	for (int e = 1; e < n; ++e) {
		int x, y;
		std::cin >> x >> y;
		adj[x].push_back(y);
		adj[y].push_back(x);
	}

	// Root at s so that following par[] from t walks back to s.
	dfs(s);

	// Collect t, par[t], ..., s (stopping at the -1 sentinel), then flip the
	// order so the path reads s -> ... -> t.
	for (; t != -1; t = par[t])
		a.push_back(t);
	std::reverse(a.begin(), a.end());
	for (int v : a)
		flag[v] = true;
	a.pop_back(); // t stays flagged but is not part of the walk

	// Evaluate every off-path branch of each path node and count them.
	for (int u : a) {
		int branches = 0;
		for (int v : adj[u]) {
			if (flag[v])
				continue;
			solve(v, u);
			++branches;
		}
		deg[u] = branches;
	}

	// Invariant: lo is rejected (-1 is a virtual reject), and hi is the best
	// candidate found so far.
	int lo = -1;
	int hi = n;
	while (hi - lo > 1) {
		const int mid = lo + (hi - lo) / 2;
		if (check(mid))
			hi = mid;
		else
			lo = mid;
	}

	std::cout << hi;
	return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...