Submission #296975

# | Time | Username | Problem | Language | Result | Execution time | Memory
296975 | RealSuperman1 | Highway Tolls (IOI18_highway) | C++14
51 / 100
200 ms20800 KiB
#pragma GCC optimize("Ofast")
 
#include <bits/stdc++.h>
#include "highway.h"
#define ll long long
#define fi first
#define se second
#define pb push_back
#define pll pair<long long, long long>
#define pii pair<int, int>
 
using namespace std;
 
// Shared state of the tree solver: dfs() fills h / tin / tout / par and
// max_height; case1234() and solve_from() bucket vertices into fixed_h.
int n;                          // number of vertices (set in find_pair)
int m;                          // number of edges (set in find_pair)
int max_height = 0;             // deepest depth recorded by dfs()
int timer;                      // Euler-tour clock feeding tin / tout
vector<vector<pii>> g;          // g[u] = list of (neighbour, edge index)
vector<vector<int>> fixed_h;    // fixed_h[d] = vertices at depth d
vector<int> h;                  // h[u] = depth of u in the current rooting
vector<int> tin;                // tin[u] = entry time of u
vector<int> tout;               // tout[u] = exit time of u
vector<pii> par;                // par[u] = (parent, edge to parent); (-1, -1) at the root
 
void dfs(int u, int p, int id_w, int height) {
	tin[u] = ++timer;
	max_height = max(max_height, height);
	h[u] = height;
	par[u] = {p, id_w};
	for (auto to : g[u])
		if (to.fi != p)
			dfs(to.fi, u, to.se, h[u] + 1);
	tout[u] = ++timer;
}
 
// True when anc's [tin, tout] interval encloses u's, i.e. anc is an ancestor
// of u (or u itself) in the rooting produced by the last dfs().
bool parent(int anc, int u) {
	const bool entered_before = tin[anc] <= tin[u];
	const bool left_after = tout[u] <= tout[anc];
	return entered_before && left_after;
}
 
// Cost of the hidden route when every edge gets toll choice 0; case1234()
// divides this value by A to recover the S-T edge distance.
ll find_val() {
	vector <int> all_zero(m);
	const ll cost = ask(all_zero);
	return cost;
}
 
// Binary-searches the largest depth d in [0, max_height] for which marking
// the parent edges of all vertices on depths d..hi as 1 makes ask() exceed
// `val` (the all-zero cost). Returns 0 when no probe exceeds `val`.
int find_lower_level(ll val) {
	int lo = 0, hi = max_height, best = 0;
	vector <int> w;
	while (lo <= hi) {
		const int mid = lo + (hi - lo) / 2;
		w.assign(m, 0);
		for (int depth = mid; depth <= hi; ++depth) {
			for (int v : fixed_h[depth]) {
				const int e = par[v].se;
				if (e >= 0)
					w[e] = 1;
			}
		}
		if (ask(w) <= val) {
			hi = mid - 1;
		} else {
			best = max(best, mid);
			lo = mid + 1;
		}
	}
	return best;
}
 
// Repeatedly halves `candidates`: the parent edges of the first half are set
// to 1; if ask() then exceeds `val` the first half is kept, otherwise the
// second half. Returns the single survivor (expects a non-empty list).
int get_answer(vector <int> candidates, ll val) {
	vector <int> w;
	while (candidates.size() > 1) {
		const size_t half = candidates.size() / 2;
		w.assign(m, 0);
		for (size_t i = 0; i < half; ++i) {
			const int e = par[candidates[i]].se;
			if (e >= 0)
				w[e] = 1;
		}
		if (ask(w) > val)
			candidates.resize(half);
		else
			candidates.erase(candidates.begin(), candidates.begin() + half);
	}
	return candidates[0];
}
 
int solve_from(int C, int dist, ll val) {
	max_height = 0; timer = 0;
	dfs(C, -1, -1, 0);
	fixed_h.resize(max_height + 1);
	for (int i = 0; i <= max_height; i++)
		fixed_h[i].clear();
	for (int i = 0; i < n; i++)
		fixed_h[h[i]].pb(i);
	return get_answer(fixed_h[dist], val);
}

// Tree instances (m == n - 1): root at vertex 0, locate the level returned by
// find_lower_level(), isolate an endpoint on it, then find the vertex at the
// all-zero distance from that endpoint.
void case1234(vector <int> U, vector <int> V, int A, int B) {
	g.resize(n);
	par.resize(n);
	h.resize(n);
	tin.resize(n);
	tout.resize(n);
	for (int e = 0; e < m; ++e) {
		const int a = U[e], b = V[e];
		g[a].pb({b, e});
		g[b].pb({a, e});
	}
	timer = 0;
	dfs(0, -1, -1, 0);
	fixed_h.resize(max_height + 1);
	for (int v = 0; v < n; ++v)
		fixed_h[h[v]].pb(v);
	const ll all_zero = find_val();
	const int dist = static_cast<int>(all_zero / static_cast<ll>(A));
	const int level = find_lower_level(all_zero);
	const int first = get_answer(fixed_h[level], all_zero);
	const int second = solve_from(first, dist, all_zero);
	answer(first, second);
}

// Edge list used by divide(): ((endpoint, endpoint), edge index).
vector<pair<pii, int>> edges;

// Toll vector for the cut (s, rest): an edge with both endpoints on the same
// side of the cut gets 1, an edge crossing the cut gets 0.
vector <int> divide(vector <int> &s) {
	vector <char> inside(n, 0);
	for (int v : s)
		inside[v] = 1;
	vector <int> w(m);
	for (const auto &e : edges) {
		const bool same_side = inside[e.fi.fi] == inside[e.fi.se];
		w[e.se] = same_side ? 1 : 0;
	}
	return w;
}

// Returns, in increasing order, every vertex id whose total number of
// occurrences across x and y is exactly 2.
vector <int> intersect(vector <int> &x, vector <int> &y) {
	vector <int> hits(n, 0);
	for (const vector <int> *src : {&x, &y})
		for (int v : *src)
			hits[v] += 1;
	vector <int> z;
	int v = 0;
	while (v < n) {
		if (hits[v] == 2)
			z.push_back(v);
		++v;
	}
	return z;
}

// Instances with A = 1, B = 2 (see find_pair's dispatch). divide(X) gives toll
// choice 1 (cost 2) to edges inside X or inside its complement and choice 0
// (cost 1) to edges crossing the cut. A route's cost parity therefore equals
// the parity of its cut crossings, which is odd exactly when S and T lie on
// different sides. One parity query per bit recovers S xor T; a second round
// of queries restricted to vertices with bit x = 0 recovers the bits of the
// endpoint that has bit x clear.
void case5(vector <int> U, vector <int> V, int A, int B) {
	edges.resize(m);
	for (int i = 0; i < m; i++)
		edges[i] = {{U[i], V[i]}, i};

	// Enough bits to encode every vertex id 0..n-1 (17 when n = 90000; a fixed
	// 16 would silently drop bit 16 for ids >= 65536).
	int bits = 1;
	while ((1 << bits) < n)
		bits++;

	// zero_bit[j] = vertices whose j-th bit is 0.
	vector < vector <int> > zero_bit(bits);
	for (int v = 0; v < n; v++)
		for (int j = 0; j < bits; j++)
			if (!((v >> j) & 1))
				zero_bit[j].pb(v);

	// xr = S xor T, one query per bit.
	int xr = 0;
	for (int i = 0; i < bits; i++) {
		vector <int> w = divide(zero_bit[i]);
		if (ask(w) % 2 == 1)
			xr |= 1 << i;
	}

	// S != T, so some bit differs; x is the lowest such bit. The guard keeps x
	// a valid index even if xr were unexpectedly 0.
	int x = 0;
	while (x + 1 < bits && !((xr >> x) & 1))
		x++;

	// Call `t` the endpoint with bit x set and `s` the one with bit x clear.
	int s = 0, t = 1 << x;
	for (int i = 0; i < bits; i++) {
		if (i == x)
			continue;
		// z = vertices with bits i and x both 0. t is never in z, so the
		// parity is odd exactly when s has bit i equal to 0.
		vector <int> z = intersect(zero_bit[i], zero_bit[x]);
		vector <int> w = divide(z);
		const int s_bit = (ask(w) % 2 == 1) ? 0 : 1;
		const int t_bit = ((xr >> i) & 1) ? (s_bit ^ 1) : s_bit;
		s |= s_bit << i;
		t |= t_bit << i;
	}
	answer(s, t);
}

void find_pair(int n1, vector <int> U, vector <int> V, int A, int B) {
	n = n1;
	m = U.size();
	if (m == n - 1) {
		case1234(U, V, A, B);
		return;
	} else if (A == 1 && B == 2) {
		case5(U, V, A, B);
		return;
	}
}

Compilation message (stderr)

highway.cpp: In function 'int get_answer(std::vector<int>, long long int)':
highway.cpp:76:23: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   76 |   for (int i = lim; i < candidates.size(); i++) {
      |                     ~~^~~~~~~~~~~~~~~~~~~
highway.cpp: In function 'void case5(std::vector<int>, std::vector<int>, int, int)':
highway.cpp:193:3: warning: 'x' may be used uninitialized in this function [-Wmaybe-uninitialized]
  193 |   if (i == x)
      |   ^~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...