/*
Submission #1226640

# | Time | Username | Problem | Language | Result | Execution time | Memory
1226640 | (not captured) | colossal_pepe | Highway Tolls (IOI18_highway) | C++20 | 51 / 100 | 138 ms | 9928 KiB
*/
#include "highway.h"
#include <bits/stdc++.h>
using namespace std;

/*
Plan for tree input (M = N - 1):
1. Root the tree at 0.
2. Binary search on depth to find max(dep(S), dep(T)). WLOG call the deeper endpoint S.
3. Binary search over the nodes at depth dep(S) to find S.
4. Re-root the tree at S.
5. T now sits at depth dist(S, T) = all_light / A, so its depth needs no search.
6. Binary search over the nodes at that depth to find T.
*/

// 64-bit integer type used for query costs.
using ll = long long;

// ---- Shared state for the tree-based search helpers ----
int n;     // number of cities (N passed to find_pair)
int root;  // current root of `t` used by the DFS helpers

ll light;      // toll A (light edge weight)
ll heavy;      // toll B (heavy edge weight)
ll all_light;  // cost reported by ask() when every edge is light
ll all_heavy;  // (all_light / light) * heavy: cost if every path edge were heavy

vector<pair<int, int>> edges;      // declared but not referenced in the visible code
vector<vector<pair<int, int>>> t;  // adjacency list: t[u] = {(edge index, neighbour)}

// Returns max(dep(S), dep(T)) for the tree `t` rooted at `root`.
// For a depth limit, an edge is made heavy iff its lower endpoint lies at
// depth <= limit. The query then costs all_heavy exactly when the whole S-T
// path lies within depth <= limit, which is monotone in the limit, so the
// smallest such limit is found by binary search.
// S != T, so the answer is at least 1, and depth n - 1 bounds every node, so
// the upper end is always valid and no final verification query is needed:
// this uses ceil(log2(n - 1)) calls to ask().
int findMaxDepth() {
	vector<int> w(n - 1);
	int limit = 0;
	// Marks every edge whose child is at depth <= limit as heavy, all others light.
	auto setWeights = [&w, &limit](const auto &self, int u, int par, int depth) -> void {
		for (auto [i, v] : t[u]) {
			if (v == par) continue;
			w[i] = (depth + 1 <= limit);
			self(self, v, u, depth + 1);
		}
	};
	int lo = 1, hi = n - 1;
	while (lo < hi) {
		limit = lo + (hi - lo) / 2;
		setWeights(setWeights, root, root, 0);
		if (ask(w) == all_heavy) hi = limit;  // whole path within depth <= limit
		else lo = limit + 1;
	}
	return lo;
}

// Returns a node at depth `target_depth` (tree `t` rooted at `root`) whose
// parent edge lies on the S-T path.
// Candidate parent edges are collected in DFS order. The binary search keeps
// the invariant that candidates [lo, hi] contain a path edge: making
// [lo, mid] heavy changes the cost away from all_light iff one of them is on
// the path. Weights are reset before `lo` moves, so no stale heavy edges
// survive between queries. Uses ceil(log2(#candidates)) calls to ask().
int findOfDepth(int target_depth) {
	vector<pair<int, int>> s;  // (edge index, child) for every child at target_depth
	auto dfs = [&](const auto &self, int u, int par, int depth) -> void {
		for (auto [i, v] : t[u]) {
			if (v == par) continue;
			if (depth + 1 == target_depth) {
				s.push_back(make_pair(i, v));
				continue;  // nothing below the target depth can be a candidate
			}
			self(self, v, u, depth + 1);
		}
	};
	dfs(dfs, root, root, 0);
	// No node at that depth: this should not happen for a consistent interaction,
	// but avoid indexing an empty vector.
	if (s.empty()) return root;

	vector<int> w(n - 1, 0);
	auto paint = [&](int from, int to, int value) {
		for (int i = from; i <= to; i++) w[s[i].first] = value;
	};
	int lo = 0, hi = static_cast<int>(s.size()) - 1;
	while (lo < hi) {
		const int mid = lo + (hi - lo) / 2;
		paint(lo, mid, 1);
		const bool onPath = ask(w) != all_light;
		paint(lo, mid, 0);  // undo using the same range, before lo/hi change
		if (onPath) hi = mid;
		else lo = mid + 1;
	}
	return s[lo].second;
}

void find_pair(int N, vector<int> U, vector<int> V, int A, int B) {
	if (U.size() != N - 1) answer(0, 1);
	n = N;
	light = A, heavy = B;
	t.resize(n);
	for (int i = 0; i < n - 1; i++) {
		t[U[i]].push_back(make_pair(i, V[i]));
		t[V[i]].push_back(make_pair(i, U[i]));
	}
	all_light = ask(vector<int>(n - 1, 0));
	all_heavy = (all_light / light) * heavy;

	cerr << all_light << ' ' << all_heavy << endl;
	root = 0;
	int S = findOfDepth(findMaxDepth());
	root = S;
	int T = findOfDepth(all_light / light);
	answer(S, T);
}
/*
Per-subtask result tables (still showing "Fetching results..." when this page was captured):

# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
*/