제출 #804729

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
804729 | IvanJ | Highway Tolls (IOI18_highway) | C++17
100 / 100
232 ms | 15388 KiB
#include "highway.h"
#include<bits/stdc++.h>

// Competitive-programming shorthands used throughout this file.
// NOTE(review): `x` and `y` are object-like macros, so EVERY identifier spelled
// `x` or `y` after this point (including function parameters and locals, e.g.
// bfs's parameter and find_pair's `int x`/`int y`) is textually rewritten to
// `first`/`second`. The code relies on no such rewrite causing a name clash.
// `all` is not used in the visible code.
#define pb push_back
#define x first
#define y second
#define all(a) (a).begin(), (a).end()

using namespace std;

typedef long long ll;
// (vertex, edge id) in adj; (endpoint, endpoint) in E.
typedef pair<int, int> ii;

// Capacity of the per-vertex arrays below; assumes n < maxn (not checked).
const int maxn = 2e5 + 5;

// n, m: vertex / edge counts, set in find_pair.
// len: edge count of the cheapest all-light route, set from get_len().
// A: per-edge cost divisor used by get_len (the weight-0 toll).
// B: stored by find_pair but not read in the visible code
//    (presumably the weight-1 toll — verify against the task statement).
// pos: edge index returned by get_first; kept light in every get_last query.
int n, m, len, A, B, pos;
// Edge i's endpoints, indexed by edge id.
vector<ii> E;
// dist[0] / dist[1]: BFS distances from the two endpoints of edge `pos`.
int dist[2][maxn];
// Visited flags for get_tree; never reset, so each vertex joins at most one tree.
int vis[maxn];
// adj[v]: list of (neighbor, edge id) pairs.
vector<ii> adj[maxn];

int get_len() {
	vector<int> w(m, 0);
	ll l = ask(w);
	return (l / A);
}

void bfs(int x, int f) {
	for(int i = 0;i < n;i++) 
		dist[f][i] = 1e9;
	dist[f][x] = 0;
	queue<int> q;
	q.push(x);
	while(q.size()) {
		x = q.front();q.pop();
		for(ii p : adj[x]) {
			if(dist[f][p.x] != 1e9) continue;
			dist[f][p.x] = dist[f][x] + 1;
			q.push(p.x);
		}
	}
}

vector<int> get_tree(int X, int f) {
	queue<int> q;
	q.push(X), vis[X] = 1;
	vector<int> ret;
	while(q.size()) {
		int x = q.front();q.pop();
		for(ii p : adj[x]) {
			if(dist[f][p.x] >= dist[!f][p.x]) continue;
			if(vis[p.x]) continue;
			vis[p.x] = 1;
			ret.pb(p.y), q.push(p.x);
		}
	}
	return ret;
}

int get_first() {
	int lo = 0, hi = m - 1, ans = -1;
	while(lo <= hi) {
		int mid = (lo + hi) / 2;
		vector<int> w(m, 0);
		for(int i = 0;i <= mid;i++) w[i] = 1;
		ll ret = ask(w);
		if(ret == (ll)len * A) lo = mid + 1;
		else ans = mid, hi = mid - 1;
	}
	return ans;
}

int get_last(vector<int> a, vector<int> b) {
	if(a.size() == 0) return -1;
	int lo = -1, hi = (int)a.size() - 1, ans = -1;
	while(lo <= hi) {
		int mid = (lo + hi) / 2;
		vector<int> w(m, 1);
		for(int i = 0;i <= mid;i++) w[a[i]] = 0;
		for(int i : b) w[i] = 0;
		w[pos] = 0;
		ll ret = ask(w);
		if(ret == (ll)len * A) 
			ans = mid, hi = mid - 1;
		else lo = mid + 1;
	}
	if(ans == -1) return -1;
	return a[ans];
}

void find_pair(int N, vector<int> U, vector<int> V, int A, int B) {
	::A = A, ::B = B;
	n = N, m = (int)U.size();
	for(int i = 0;i < m;i++) {
		E.pb({U[i], V[i]});
		adj[U[i]].pb({V[i], i}), adj[V[i]].pb({U[i], i});
	}
	len = get_len();
	
	pos = get_first();
	ii e = E[pos];
	//cout << e.x << " " << e.y << " first\n";
	bfs(e.x, 0);
	bfs(e.y, 1);
	
	vector<int> a = get_tree(e.x, 0);
	vector<int> b = get_tree(e.y, 1);
	//for(int i : a) cout << E[i].x << " " << E[i].y << " a\n";
	//for(int i : b) cout << E[i].x << " " << E[i].y << " b\n";
	int x = get_last(a, b);
	int y = get_last(b, a);
	//cout << x << " " << y << "\n";
	
	int s, t;
	if(x == -1) s = e.x;
	else {
		if(dist[0][E[x].x] > dist[0][E[x].y])
			s = E[x].x;
		else s = E[x].y;
	}
	
	if(y == -1) t = e.y;
	else {
		if(dist[1][E[y].x] > dist[1][E[y].y])
			t = E[y].x;
		else t = E[y].y;
	}
	answer(s, t);
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...