Submission #297615

# | Time | Username | Problem | Language | Result | Execution time | Memory
297615 | - | rqi | Highway Tolls (IOI18_highway) | C++14 | 51 / 100 | 336 ms | 31740 KiB
#include "highway.h"
#include <bits/stdc++.h>
using namespace std;

// Short type aliases used throughout the solution.
typedef vector<int> vi;
typedef long long ll;
typedef pair<int, int> pi;
typedef vector<pi> vpi;

// Abbreviation macros. NOTE: `f` and `s` are rewritten to `first`/`second`
// everywhere below this point, so they must not be used as ordinary
// identifiers in this file.
#define pb push_back
#define mp make_pair
#define f first
#define s second
#define sz(x) (int)(x).size()
#define ins insert

// Size of the fixed per-node arrays (adj, sub, und, inSet).
const int mx = 90005;
// Node count and edge count; assigned in find_pair (M is the size of U).
int N, M;
// The two toll values handed to find_pair (_A and _B), widened to long long.
ll A, B;
// Result of getDist(): ask() with an all-zero weight vector, divided by A.
int D;
// Endpoint lists of the edges; edge i joins U[i] and V[i].
vi U, V;
// Adjacency lists: adj[x] holds (neighbour node, edge index) pairs.
vpi adj[mx]; //node, edge

int getDist(){
	vi w(M, 0);
	return ask(w)/A;
}

vpi jeds; //edges just before degree d nodes, label of nodes
int Rdist;

void findJeds(int node, int d = 0, int prv = -1){
	for(auto u: adj[node]){
		if(u.f == prv) continue;
		if(d+1 == Rdist){
			jeds.pb(mp(u.s, u.f));
			continue;
		}
		findJeds(u.f, d+1, node);
	}
}

// Locates, inside the tree given by `eds`, the node at depth `d` below `R`
// that the grader's answers single out.
//
// eds : tree edges as ((u, v), edge index); adj is rebuilt from them.
// R   : root from which depths are measured.
// d   : depth of the wanted node; also stored in the global Rdist.
//
// All nodes at depth d are collected together with the edge entering them.
// The list is then halved repeatedly: the entering edges of the first half
// are given weight 1 and every other edge weight 0; if ask() still reports
// A*D the first half is discarded, otherwise the second half is.
// NOTE(review): this relies on ask() returning the cheapest-path cost under
// the given weights, as declared by highway.h - confirm against the task.
//
// Returns R itself when d <= 0 (the original indexed an empty list there).
int treeSolveEz(vector<pair<pi, int>> eds, int R, int d){ //Rdist = d
	Rdist = d;
	for(const auto& u: eds){
		adj[u.f.f].clear(); adj[u.f.s].clear();
	}
	for(const auto& u: eds){
		adj[u.f.f].pb(mp(u.f.s, u.s)); adj[u.f.s].pb(mp(u.f.f, u.s));
	}

	jeds.clear();

	// Depth 0 is the root itself; there is no entering edge to search over.
	if(d <= 0) return R;

	findJeds(R);
	assert(!jeds.empty()); // some node must exist at the requested depth

	while(sz(jeds) > 1){
		const int half = sz(jeds)/2;
		vi w(M, 0);
		for(int i = 0; i < half; i++){
			w[jeds[i].f] = 1;
		}
		if(ask(w) == A*D){
			// No weighted edge was used: keep the second half.
			jeds.erase(jeds.begin(), jeds.begin()+half);
		}
		else{
			jeds.resize(half);
		}
	}
	return jeds[0].s;
}

int n;

int sub[mx];

void genSub(int node, int prv = -1){
	sub[node] = 1;
	for(auto u: adj[node]){
		if(u.f == prv) continue;
		genSub(u.f, node);
		sub[node]+=sub[u.f];
	}
}

int findCen(int node, int prv = -1){
	for(auto u: adj[node]){
		if(u.f == prv) continue;
		if(sub[u.f] >= n-sub[u.f]) return findCen(u.f, node);
	}
	return node;
}

int und[mx];

void genUnd(int node, int lab, int prv){
	und[node] = lab;
	for(auto u: adj[node]){
		if(u.f == prv) continue;
		genUnd(u.f, lab, node);
	}
}

// inSet[c]: for a child c of the current centroid, whether c's subtree was
// placed in the group of edges that receive weight 1 in the query.
bool inSet[mx];

// Finds the two endpoints hidden in the tree described by `eds`.
//
// eds : tree edges as ((u, v), edge index). With a single edge its two
//       endpoints are returned directly.
//
// Steps:
//   1. Rebuild adj from eds and locate a centroid R via genSub/findCen.
//   2. Greedily pick children of R so that the picked subtrees hold close to
//      half of the remaining n-1 nodes; their edges form `neds`, the rest
//      `oeds`.
//   3. Ask with neds weighted 1 and everything else 0.
//        - result A*D : no picked edge was used -> recurse into oeds.
//        - result B*D : only picked edges were used -> recurse into neds.
//        - otherwise  : (res - A*D)/(B - A) picked edges were used, so one
//          endpoint lies that deep below R in neds and the other lies at the
//          complementary depth in oeds; treeSolveEz finds each of them.
// NOTE(review): the interpretation of ask() follows highway.h - confirm
// against the task statement.
pi treeSolve(vector<pair<pi, int>> eds){
	n = sz(eds)+1;
	if(n == 2) return eds[0].f;

	for(const auto& u: eds){
		adj[u.f.f].clear(); adj[u.f.s].clear();
	}
	for(const auto& u: eds){
		adj[u.f.f].pb(mp(u.f.s, u.s)); adj[u.f.s].pb(mp(u.f.f, u.s));
	}

	int R = eds[0].f.f;
	genSub(R);
	R = findCen(R);

	// Recompute subtree sizes with the centroid as root.
	genSub(R);
	und[R] = -1;
	for(const auto& u: adj[R]){
		genUnd(u.f, u.f, R);
		inSet[u.f] = 0;
	}
	assert(sz(adj[R]) >= 2); //at least two subtrees

	// Greedy split: take a child whenever doing so does not make the smaller
	// of the two groups any smaller.
	int picked = 0;
	for(const auto& u: adj[R]){
		const int newval = picked+sub[u.f];
		if(min(newval, n-1-newval) >= min(picked, n-1-picked)){
			picked = newval;
			inSet[u.f] = 1;
		}
	}

	vi w(M, 0);
	vector<pair<pi, int>> neds;
	vector<pair<pi, int>> oeds;

	for(const auto& u: eds){
		// Classify the edge by its endpoint that is not the centroid.
		const int a = (u.f.f == R) ? u.f.s : u.f.f;
		if(inSet[und[a]]){
			w[u.s] = 1;
			neds.pb(u);
		}
		else oeds.pb(u);
	}

	const ll res = ask(w);
	if(res == A*D){
		return treeSolve(oeds);
	}
	if(res == B*D){
		return treeSolve(neds);
	}

	const int inneds = (res-A*D)/(B-A);
	pi ans;
	ans.f = treeSolveEz(neds, R, inneds);
	ans.s = treeSolveEz(oeds, R, D-inneds);
	return ans;
}

void find_pair(int _N, vi _U, vi _V, int _A, int _B) {
	N = _N;
	U = _U;
	V = _V;
	M = sz(U);
	for(int i = 0; i < M; i++){
		adj[U[i]].pb(mp(V[i], i));
		adj[V[i]].pb(mp(U[i], i));
	}
	A = _A;
	B = _B;


	if(M == N-1){ //TREE CASE
		// bool isLine = 1;
		// for(int i = 0; i < M; i++){
		// 	if(U[i] != i || V[i] != i+1) isLine = 0;
		// }
		// if(isLine){
		// 	D = getDist();
		// 	vi sts;
		// 	for(int i = 0; i < N; i++){
		// 		if(i+D < N) sts.pb(i);
		// 	}

		// 	while(sz(sts) > 1){
		// 		vi nsts;
		// 		vi osts;
		// 		for(int i = 0; i < sz(sts)/2; i++){
		// 			nsts.pb(sts[i]);
		// 		}
		// 		for(int i = sz(sts)/2; i < sz(sts); i++){
		// 			osts.pb(sts[i]);
		// 		}
		// 		vi w(M, 0);
		// 		for(auto u: nsts) w[u] = 1;
		// 		if(ask(w) == A*D){
		// 			sts = osts;
		// 		}
		// 		else sts = nsts;
		// 	}
		// 	answer(sts[0], sts[0]+D);
		// 	return;
		// }
		D = getDist();
		vector<pair<pi, int>> eds;
		for(int i = 0; i < M; i++){
			eds.pb(mp(mp(U[i], V[i]), i));
		}

		pi ans = treeSolve(eds);
		answer(ans.f, ans.s);

		return;

		//0 is one of the parents
		D = getDist();
		findJeds(0, 0);
		while(sz(jeds) > 1){
			vpi njeds;
			vpi ojeds;
			for(int i = 0; i < sz(jeds)/2; i++){
				njeds.pb(jeds[i]);
			}
			for(int i = sz(jeds)/2; i < sz(jeds); i++){
				ojeds.pb(jeds[i]);
			}
			vi w(M, 0);
			for(auto u: njeds){
				w[u.f] = 1;
			}
			if(ask(w) == A*D){
				jeds = ojeds;
			}
			else jeds = njeds;
		}
		answer(0, jeds[0].s);
		return;
	}
	

}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...