제출 #594048

#제출 시각아이디문제언어결과실행 시간메모리
594048Temmie통행료 (IOI18_highway)C++17
6 / 100
122 ms12440 KiB
#include "highway.h"
#include <bits/stdc++.h>

int n, m;
std::vector <std::vector <int>> g;

long long dist;

void find_pair(int _n, std::vector <int> _u, std::vector <int> _v, int a, int b) {
	n = _n;
	m = _u.size();
	g.resize(n);
	for (int i = 0; i < m; i++) {
		g[_u[i]].push_back(i);
		g[_v[i]].push_back(i);
	}
	dist = ask(std::vector <int> (m, 0)) / a;
	
	int onw = -1, onwe = -1;
	{
		std::vector <int> cur(m, 1);
		int left = m;
		while (left > 1) {
			int mid = left / 2;
			auto cop = cur;
			for (int i = 0; i < m && left > mid; i++) {
				if (cur[i]) {
					cur[i] = 0;
					left--;
				}
			}
			long long val = ask(cur);
			if (val >= (dist - 1) * a + b) {
				left = 0;
				for (int x : cur) {
					left += x;
				}
			} else {
				left = 0;
				for (int i = 0; i < m; i++) {
					if (cop[i] && !cur[i]) {
						cop[i] = 1;
					} else if (cop[i] && cur[i]) {
						cop[i] = 0;
					}
					left += cop[i];
				}
				cur = cop;
			}
		}
		assert(left == 1);
		for (int i = 0; i < m; i++) {
			if (cur[i]) {
				onw = _u[onwe = i];
				break;
			}
		}
		assert(onw != -1);
	}
	int dtf = -1;
	std::vector <int> par(n, -1);
	std::vector <std::vector <int>> ofd(n + 1);
	{
		int half = dist / 2;
		std::vector <int> cur(m, 1);
		std::queue <std::pair <int, std::pair <int, int>>> q;
		std::vector <bool> vis(n, false);
		q.push({ -1, { onw, 0 } });
		while (q.size()) {
			int v = q.front().second.first;
			int w = q.front().second.second;
			int pa = q.front().first;
			q.pop();
			if (vis[v]) {
				continue;
			}
			vis[v] = true;
			par[v] = pa;
			ofd[w].push_back(v);
			if (w <= half) {
				cur[pa] = 0;
			}
			for (int x : g[v]) {
				int to = _u[x] ^ _v[x] ^ v;
				q.push({ x, { to, w + 1 } });
			}
		}
		long long val = ask(cur);
		val -= dist * a;
		int heavies = val / (b - a);
		dtf = half + heavies;
	}
	if (dtf == dist / 2) {
		dtf++;
		onw = _v[onwe];
		par = std::vector <int> (n, -1);
		ofd = std::vector <std::vector <int>> (n + 1);
		std::queue <std::pair <int, std::pair <int, int>>> q;
		std::vector <bool> vis(n, false);
		q.push({ -1, { onw, 0 } });
		while (q.size()) {
			int v = q.front().second.first;
			int w = q.front().second.second;
			int pa = q.front().first;
			q.pop();
			if (vis[v]) {
				continue;
			}
			vis[v] = true;
			par[v] = pa;
			ofd[w].push_back(v);
			for (int x : g[v]) {
				int to = _u[x] ^ _v[x] ^ v;
				q.push({ x, { to, w + 1 } });
			}
		}
	}
	int s = -1;
	{
		std::vector <int> cur(m, 0);
		int left = ofd[dtf].size();
		while (left > 1) {
			int mid = left / 2;
			auto cop = cur;
			for (int i = 0; i < (int) ofd[dtf].size() && left > mid; i++) {
				if (!cur[ofd[dtf][i]]) {
					cur[ofd[dtf][i]] = 1;
					left--;
				}
			}
			long long val = ask(cur);
			if (val == dist * a) {
				left = 0;
				for (int x : ofd[dtf]) {
					left += !cur[x];
				}
			} else {
				left = 0;
				for (int i = 0; i < (int) ofd[dtf].size(); i++) {
					if (!cur[ofd[dtf][i]]) {
						cop[ofd[dtf][i]] = 1;
					}
					left += !cop[ofd[dtf][i]];
				}
				cur = cop;
			}
		}
		assert(left == 1);
		for (int x : ofd[dtf]) {
			if (!cur[par[x]]) {
				s = x;
				break;
			}
		}
		assert(s != -1);
	}
	
	std::vector <int> cand;
	{
		std::vector <int> ok(n, 0);
		for (int x : ofd[dist - dtf]) {
			ok[x]++;
		}
		std::queue <std::pair <int, int>> q;
		q.push({ s, 0 });
		std::vector <bool> vis(n, false);
		while (q.size()) {
			int v = q.front().first;
			int w = q.front().second;
			q.pop();
			if (vis[v]) {
				continue;
			}
			vis[v] = true;
			ok[v] += w == dist;
			for (int x : g[v]) {
				int to = _u[x] ^ _v[x] ^ v;
				q.push({ to, w + 1 });
			}
		}
		for (int i = 0; i < n; i++) {
			if (ok[i] == 2) {
				cand.push_back(i);
			}
		}
		assert(cand.size());
	}
	int e = -1;
	{
		std::vector <int> cur(m, 0);
		int left = cand.size();
		while (left > 1) {
			int mid = left / 2;
			auto cop = cur;
			for (int i = 0; i < (int) cand.size() && left > mid; i++) {
				if (!cur[par[cand[i]]]) {
					cur[par[cand[i]]] = 1;
					left--;
				}
			}
			long long val = ask(cur);
			if (val == dist * a) {
				left = 0;
				for (int x : cand) {
					left += !cur[x];
				}
			} else {
				left = 0;
				for (int i = 0; i < (int) cand.size(); i++) {
					if (!cur[par[cand[i]]]) {
						cop[par[cand[i]]] = 1;
					}
					left += !cop[par[cand[i]]];
				}
				cur = cop;
			}
		}
		assert(left == 1);
		for (int x : cand) {
			if (!cur[par[x]]) {
				e = x;
				break;
			}
		}
		assert(e != -1);
	}
	answer(s, e);
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...