Submission #290588

# | Time | Username | Problem | Language | Result | Execution time | Memory
290588 | ChrisT | Highway Tolls (IOI18_highway) | C++17
51 / 100
673 ms | 29416 KiB
#include <bits/stdc++.h>
#include "highway.h"
using namespace std;
const int MN = 9e4 + 5;
// adj[v]: (neighbour, edge id) pairs of vertex v.
vector<pair<int,int>> adj[MN];
// atDep[d]: (vertex, id of the edge to its parent) for every recorded vertex at depth d.
vector<pair<int,int>> atDep[MN];
// par[v] = (parent vertex, parent edge id) in the DFS tree; w: per-edge weight vector.
pair<int,int> par[MN];
vector<int> w;
// depth[v]: distance from the DFS root; mxDep: largest depth seen by dfs() so far.
int depth[MN], mxDep;

// Pre-order DFS over the tree stored in adj, starting at `cur` whose parent is
// `p` (-1 for a root). depth[cur] is taken as already set. For every vertex
// reached below `cur`, it sets par[] and depth[] and appends
// (vertex, parent edge) to atDep[depth]. `cur` itself is appended only when
// p != -1. mxDep is raised to the deepest depth visited.
//
// Only the immediate parent is skipped, so adj must be acyclic (a tree or a
// forest). Otherwise the walk never terminates.
//
// An explicit stack replaces recursion. On a path-shaped tree the recursion
// depth would equal the vertex count (up to MN), which can overflow the call
// stack. The visiting order is identical to the recursive formulation.
void dfs (int cur, int p = -1) {
	// One frame per vertex on the current root-to-vertex chain.
	struct Frame { int v, parent; size_t next; };
	vector<Frame> stk;
	auto enter = [&] (int v, int parent) {
		if (~parent) atDep[depth[v]].push_back({v, par[v].second});
		mxDep = max(mxDep, depth[v]);
		stk.push_back({v, parent, 0});
	};
	enter(cur, p);
	while (!stk.empty()) {
		Frame &top = stk.back();
		if (top.next == adj[top.v].size()) { stk.pop_back(); continue; }
		const auto [nb, id] = adj[top.v][top.next++];
		// Copy before enter(): push_back may invalidate `top`.
		const int here = top.v;
		if (nb == top.parent) continue;
		par[nb] = {here, id};
		depth[nb] = depth[here] + 1;
		enter(nb, here);
	}
}
void find_pair (int n, vector<int> u, vector<int> v, int a, int b) {
	assert((int)u.size() == n - 1);
	for (int i = 0; i + 1 < n; i++) {
		adj[++u[i]].emplace_back(++v[i],i);
		adj[v[i]].emplace_back(u[i],i);
	}
	w.resize(n-1);
	long long smallDist = ask(w);
	dfs(1);
	auto find = [&] (set<int> banned) {
		int low = 1, high = mxDep, mid, ans = -1;
		while (low <= high) {
			mid = (low + high) / 2;
			for (int i = mid; i <= mxDep; i++) for (auto p : atDep[i]) if (!banned.count(p.second)) w[p.second] = 1;
			if (ask(w) != smallDist) low = (ans = mid) + 1;
			else high = mid - 1;
			for (int i = mid; i <= mxDep; i++) for (auto p : atDep[i]) if (!banned.count(p.second))w[p.second] = 0;
		}
		if (!~ans) return -1;
		function<int(int,int)> get = [&] (int l, int r) {
			if (l == r) return atDep[ans][l].first;
			int mid = (l + r) / 2;
			for (int i = l; i <= mid; i++) if (!banned.count(atDep[ans][i].second)) w[atDep[ans][i].second] = 1;
			long long got = ask(w);
			for (int i = l; i <= mid; i++) if (!banned.count(atDep[ans][i].second)) w[atDep[ans][i].second] = 0;
			if (got != smallDist) return get(l,mid);
			return get(mid+1,r);
		};
		return get(0,(int)atDep[ans].size() - 1);
	};
	int s = find({});
	set<int> ban; int cur = s;
	while (cur != 1) {
		ban.insert(par[cur].second);
		cur = par[cur].first;
	}
	int t = find(ban);
	if (!~t) { //answer is on root-->s path
		vector<pair<int,int>> go;
		cur = s;
		while (cur != 1) {
			go.push_back(par[cur]);
			cur = par[cur].first;
		}
		int low = 0, high = (int)go.size() - 1,mid,ans=-1;
		while (low <= high) {
			mid = (low + high) / 2;
			w[go[mid].second] = 1;
			if (ask(w) != smallDist) low = (ans = mid) + 1;
			else high = mid - 1;
			w[go[mid].second] = 0;
		}
		assert(~ans);
		t = go[ans].first;
	}
	answer(--s,--t);
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...