제출 #594803

#제출 시각아이디문제언어결과실행 시간메모리
594803Temmie통행료 (IOI18_highway)C++17
5 / 100
173 ms10080 KiB
#include "highway.h"
#include <bits/stdc++.h>

int n, m;
std::vector <std::vector <int>> g;

long long dist;

void find_pair(int _n, std::vector <int> _u, std::vector <int> _v, int a, int b) {
	n = _n;
	m = _u.size();
	g.resize(n);
	for (int i = 0; i < m; i++) {
		g[_u[i]].push_back(i);
		g[_v[i]].push_back(i);
	}
	dist = ask(std::vector <int> (m, 0)) / a;
	int root[2] = { -1, -1 };
	{
		std::vector <int> cur(m, 1);
		int left = m;
		while (left > 1) {
			auto cop = cur;
			int mid = left >> 1;
			for (int i = 0; i < m && left > mid; i++) {
				if (cur[i]) {
					left--;
					cur[i] = 0;
				}
			}
			long long val = ask(cur);
			if (val > dist * a) {
				left = 0;
				for (int x : cur) {
					left += x;
				}
			} else {
				left = 0;
				for (int i = 0; i < m; i++) {
					if (cur[i] && cop[i]) {
						cop[i] = 0;
					}
					left += cop[i];
				}
				cur = cop;
			}
		}
		assert(left == 1);
		for (int i = 0; i < m; i++) {
			if (cur[i]) {
				root[0] = _u[i];
				root[1] = _v[i];
				break;
			}
		}
	}
	assert(root[0] != -1 && root[1] != -1);
	std::vector <int> nod[2];
	std::vector <int> par(n, -1);
	{
		std::queue <std::pair <int, std::pair <int, int>>> q;
		q.push({ root[0], { -1, root[0] } });
		q.push({ root[1], { -1, root[1] } });
		std::vector <bool> vis(n, false);
		while (q.size()) {
			int rt = q.front().first;
			int pa = q.front().second.first;
			int v = q.front().second.second;
			q.pop();
			if (vis[v]) {
				continue;
			}
			vis[v] = true;
			par[v] = pa;
			nod[rt == root[0] ? 0 : 1].push_back(v);
			for (int x : g[v]) {
				int to = v ^ _u[x] ^ _v[x];
				q.push({ rt, { x, to } });
			}
		}
	}
	int ans[2];
	for (int i = 0, uwd; i <= 1; i++) {
		if (!i) {
			std::vector <int> cur(m, 1);
			for (int x : nod[i]) {
				if (par[x] != -1) {
					cur[par[x]] = 0;
				}
			}
			long long val = ask(cur);
			val -= dist * a;
			uwd = dist - val / (b - a);
		} else {
			uwd = dist - uwd - 1;
		}
		int l = 1, r = (int) nod[i].size() - 1, best = nod[i][0];
		while (l <= r) {
			int mid = (l + r) >> 1;
			std::vector <int> cur(m, 1);
			for (int j = 1; j < mid; j++) {
				cur[par[nod[i][j]]] = 0;
			}
			long long val = ask(cur);
			if (val > a * uwd + b * (dist - uwd)) {
				best = nod[i][mid];
				l = mid + 1;
			} else {
				r = mid - 1;
			}
		}
		ans[i] = best;
	}
	answer(ans[0], ans[1]);
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...