Submission #1364817

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|---|---|---|---|---|---|---|
| 1364817 | | ByeWorld | Highway Tolls (IOI18_highway) | C++20 | 51 / 100 | 203 ms | 327680 KiB |
#include "highway.h"
#include <bits/stdc++.h>
// GCC-specific optimisation pragma.
#pragma GCC optimize("O3", "Ofast")
// Contest shorthand macros.
// NOTE: `md` expands to ((l+r)>>1), so it only works where locals named `l`
// and `r` are in scope. `lf`/`rg` (reference a variable `id`) are not used in
// this file.
#define ll long long
#define se second
#define fi first
#define pb push_back
#define lf (id<<1)
#define rg ((id<<1)|1)
#define md ((l+r)>>1)
using namespace std;
// pii is used for adjacency entries; ipii is not used in this file.
typedef pair<int,int> pii;
typedef pair<pii,pii> ipii;
// Capacity of the per-vertex arrays (vertices are stored 1-based).
// MAXA, SQRT, INF, MOD, MOD2 and LOG are not referenced in this file.
const int MAXN = 6e5+10;
const int MAXA = 5e4+10;
const int SQRT = 300;
const int INF = 2e9;
const int MOD = 1e9+87;
const int MOD2 = 1e9+7;
const int LOG = 30;

// n, a, b, m: copies of N, A, B and the edge count given to find_pair.
// x, y: the two endpoints (1-based) passed, minus one, to answer().
// val[v]: index of the edge joining v to its DFS parent (written by dfs/dfs2).
// dep[v]: depth of v in the DFS tree rooted at vertex 1.
int n, a, b, m, x, y, val[MAXN], dep[MAXN];
// adj[v]: (neighbour, edge index) pairs; vertices are shifted to 1-based.
vector<pii> adj[MAXN];
vector<int> vec; // visit order (vertices collected by dfs/dfs2)

// Depth threshold: only vertices deeper than MN are collected into `vec`.
int MN;

// Recursive DFS from `nw` (arriving from `pa`) that visits neighbours in
// REVERSE adjacency order.
// - Sets dep[] for every vertex except vertex 1, whose dep[] is never written.
// - Stores in val[child] the index of the edge used to reach each child.
// - Appends to `vec`, in pre-order, every vertex whose depth exceeds MN.
// NOTE(review): only the parent is skipped, so this assumes the graph is a tree.
void dfs2(int nw, int pa){
	if(nw != 1) dep[nw] = dep[pa] + 1;
	if(dep[nw] > MN) vec.push_back(nw);

	for(auto it = adj[nw].rbegin(); it != adj[nw].rend(); ++it){
		const int child = it->first;
		if(child == pa) continue;
		val[child] = it->second;
		dfs2(child, nw);
	}
}
// par[v]: parent of v in the DFS tree (0 for the root), written by dfs().
int par[MAXN];

// Recursive DFS from `nw` (arriving from `pa`) in forward adjacency order.
// - Records par[nw].
// - Sets dep[] for every vertex except vertex 1.
// - Stores in val[child] the index of the edge used to reach each child.
// - Appends to `vec`, in pre-order, every vertex whose depth exceeds MN.
// NOTE(review): only the parent is skipped, so this assumes the graph is a tree.
void dfs(int nw, int pa){
	par[nw] = pa;
	if(nw != 1) dep[nw] = dep[pa] + 1;
	if(dep[nw] > MN) vec.push_back(nw);

	for(const auto& edge : adj[nw]){
		const int child = edge.first;
		if(child != pa){
			val[child] = edge.second;
			dfs(child, nw);
		}
	}
}

// Query results kept between searches: DIS (path length in edges, derived
// from the all-light cost) and MX (reference cost the searches compare to).
ll DIS, MX;

// Makes heavy the edges val[vec[0]], ..., val[vec[mid]] (the edge into each of
// the first mid+1 collected vertices), keeps every other edge light, and
// returns the cost reported by ask(). A negative `mid` marks nothing.
ll ceksuff(int mid){
	vector<int> weights(m, 0);
	int k = 0;
	while(k <= mid) weights[val[vec[k++]]] = 1;
	return ask(weights);
}
// Same query as ceksuff(): heavy edges val[vec[0..mid]], all others light;
// returns ask()'s cost. The original body was a verbatim copy of ceksuff(),
// so it delegates instead, keeping the two searches from drifting apart.
ll cek(int mid){
	return ceksuff(mid);
}
// Not referenced anywhere in this file.
int lca = 1;

// Depth threshold used by build(): edges whose lower endpoint has depth <= DEP
// are marked heavy.
int DEP;
// Weight vector filled by build() and passed to ask() by find().
vector<int> w;

// Walks the tree below `nw` (arriving from `pa`).
// - Recomputes dep[] for every descendant.
// - Sets w[edge] = 1 for each tree edge whose child endpoint lies at depth <= DEP.
// NOTE(review): only the parent is skipped, so this assumes the graph is a tree.
void build(int nw, int pa){
	for(const auto& [child, edgeId] : adj[nw]){
		if(child != pa){
			dep[child] = dep[nw] + 1;
			if(dep[child] <= DEP) w[edgeId] = 1;
			build(child, nw);
		}
	}
}
int mx;
void find(){
	int l=1, r=mx, mid=0, cnt=0;
	while(l<=r){
		mid = md; 
		// mid = 1; 
		DEP = mid;
		w.clear(); w.resize(m, 0);
		build(1, 0);

		// cout << mid << ' '<< ask(w) << ' ' << DIS*a << " ask\n";
		if(ask(w) == 1ll*DIS*a){
			// cout << mid << " mid\n";
			cnt = mid, l = mid+1;
		} else r = mid-1;
		// break;
	}
	// depnya di cnt
	MN = cnt;
}
void find_pair(int N, std::vector<int> U, std::vector<int> V, int A, int B) {
	n = N; a = A; b = B; m = U.size();
	for(int i=0; i<m; i++){
		adj[U[i]+1].pb({V[i]+1, i});
		adj[V[i]+1].pb({U[i]+1, i});
	}
	dfs(1,0);
	for(int i=1; i<=n; i++) mx = max(mx, dep[i]);

	vector<int> w(m, 0);
	DIS = ask(w) / a;

	find();
	// cout << MN << " Mn\n"; // lcanya di sini

	MX = cek(vec.size()-1); // kalo semua jadi b
	vec.clear();
	dfs(1, 0);
	{
		// cout << mx << " mx\n";
		int l=0, r=vec.size()-1, mid=0, cnt=-1;
		while(l<=r){
			mid = md;
			if(cek(mid) == MX) cnt = mid, r = mid-1;
			else l = mid+1;
		}
		if(cnt == -1) assert(false);
		// cout << cnt << ' '<< vec[cnt] << ' '<< cek(4)<<' '<<MX << " cnt\n";
		y = vec[cnt];
		// cout << y <<" y\n";
	}
	vec.clear();
	dfs2(1,0);
	// for(auto in : vec) cout << in << " in\n";
	{
		int l=0, r=vec.size()-1, mid=0, cnt=-1;
		while(l<=r){
			mid = md;
			if(ceksuff(mid) == MX) cnt = mid, r = mid-1;
			else l = mid+1;
		}
		if(cnt == -1) assert(false);
		// cout << cnt << ' '<< vec[cnt] << ' '<< ceksuff(4) << " cnt\n";
		x = vec[cnt];
		// cout << x <<" x\n";
	}
	// cout << x << ' '<< y << "Xy\n";
	if(x==y){
		// cout << MN << " mn\n";
		while(dep[x] != MN) x = par[x];
	}
	// cout << x << ' '<< y << "Xy\n";

	answer(x-1, y-1);
}
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...