Submission #825768

#TimeUsernameProblemLanguageResultExecution timeMemory
825768pawnedSplit the Attractions (IOI19_split)C++17
7 / 100
41 ms10564 KiB
#include <bits/stdc++.h>
using namespace std;

#define fi first
#define se second
#define pb push_back
typedef long long ll;
typedef pair<int, int> ii;
typedef vector<int> vi;

#include "split.h"

// Problem parameters, filled in by find_split: vertex count, edge count and
// the three requested group sizes.
int N, M;
int A, B, C;

// adj0: adjacency lists of the input graph.
// adj:  adjacency lists of the DFS spanning tree extracted by dfs0().
vi adj0[100005];
vi adj[100005];

bool vis0[100005];

// Depth-first search over the input graph (adj0) starting at `n`.
// Every edge that discovers a not-yet-visited vertex is recorded in both
// directions in adj, so adj ends up holding a spanning tree of the part of
// the graph reachable from `n`.
void dfs0(int n) {
	vis0[n] = true;
	for (size_t k = 0; k < adj0[n].size(); ++k) {
		const int nxt = adj0[n][k];
		if (vis0[nxt])
			continue;	// already part of the tree
		adj[n].push_back(nxt);
		adj[nxt].push_back(n);
		dfs0(nxt);
	}
}

bool vis[100005];
int parent[100005];	// parent of each vertex reached by dfs(); the start vertex's entry is left to the caller
int subtr[100005];	// number of vertices in each vertex's subtree

// Roots the tree stored in adj at `n`. It records parent[] for every vertex
// discovered below n and fills subtr[] with subtree sizes. A vertex's size is
// written after all of its children have been processed.
void dfs(int n) {
	vis[n] = true;
	int total = 1;	// n itself
	for (int child : adj[n]) {
		if (vis[child])
			continue;	// the edge back to n's own parent
		parent[child] = n;
		dfs(child);
		total += subtr[child];
	}
	subtr[n] = total;
}

int blacklist, cap;	// vertex dfs2 must never enter, and the maximum size of `passed`

set<int> passed;	// vertices collected by dfs2

bool vis2[100005];

// Depth-first walk over the tree in adj starting at `n`. It never steps onto
// `blacklist` or any vertex already marked in vis2. Each reached vertex is
// inserted into `passed` while `passed` holds fewer than `cap` vertices.
// Once the cap is hit the walk keeps going and marks vis2, but collects nothing.
void dfs2(int n) {
	vis2[n] = true;
	const bool room = static_cast<int>(passed.size()) < cap;
	if (room)
		passed.insert(n);
	for (int nb : adj[n])
		if (nb != blacklist && !vis2[nb])
			dfs2(nb);
}

// Builds the two vertex groups for a cut above vertex x in the rooted tree.
//  - first:  up to `a` vertices collected by dfs2 starting at x. The walk is
//            barred from parent[x], so it stays inside x's subtree.
//  - second: up to `b` vertices collected by dfs2 starting at vertex 0. The
//            walk is barred from x, so it stays outside x's subtree.
// dfs2 adds a vertex only while under its cap, and a vertex is reached only
// after its tree parent. Each group is therefore a connected prefix of the walk.
// Relies on adj and parent[] prepared by dfs0()/dfs(). Overwrites the globals
// blacklist, cap and passed, and marks vis2.
pair<set<int>, set<int>> solve(int x, int a, int b) {
	// Phase 1: inside the subtree rooted at x.
	blacklist = parent[x];
	cap = a;
	dfs2(x);
	set<int> inside(passed);

	// Phase 2: everything reachable from the root without passing through x.
	passed.clear();
	blacklist = x;
	cap = b;
	dfs2(0);
	set<int> outside(passed);

	return make_pair(inside, outside);
}

// Splits the n vertices of the graph with edges (p[i], q[i]) into groups
// labelled 1, 2, 3 of sizes a, b and c. It assumes a + b + c == n (per the
// IOI 2019 "Split the Attractions" task; TODO confirm).
// Returns res with res[v] = label of vertex v, or all zeros when no split
// is found.
//
// Two strategies:
//  * Every vertex has degree <= 2, so a connected graph is a path or a cycle.
//    Walk it once and cut the walk into consecutive runs of a, b and c vertices.
//  * Otherwise, build a DFS spanning tree and look for a tree edge
//    (parent[i], i) whose two sides can each host one group. The subtree side
//    gets some group, the outside gets the smaller of the other two, and the
//    leftover (largest remaining) group takes every other vertex and need not
//    be connected.
//    NOTE(review): only single cuts of one DFS tree are tried, so on non-tree
//    graphs a valid split may be missed.
vi find_split(int n, int a, int b, int c, vi p, vi q) {
	N = n;
	A = a; B = b; C = c;
	M = p.size();
	for (int i = 0; i < M; i++) {
		adj0[p[i]].pb(q[i]);
		adj0[q[i]].pb(p[i]);
	}

	bool maxDegTwo = true;
	for (int i = 0; i < N; i++) {
		if (adj0[i].size() > 2)
			maxDegTwo = false;
	}

	if (maxDegTwo) {
		// Start from an endpoint if there is one (path). On a cycle every
		// vertex has degree 2 and vertex 0 is as good a start as any.
		int endpt = 0;
		for (int i = 0; i < N; i++) {
			if (adj0[i].size() < 2)
				endpt = i;
		}
		vi order;
		order.reserve(N);
		int prev = -1;
		int curr = endpt;
		while (true) {
			order.pb(curr);
			// Stop before stepping past the last vertex. The far endpoint of a
			// path has a single neighbour, so adj0[curr][1] would be out of range.
			if ((int)order.size() == N)
				break;
			int next = (adj0[curr][0] == prev) ? adj0[curr][1] : adj0[curr][0];
			prev = curr;
			curr = next;
		}
		vi answer(N);
		for (int i = 0; i < N; i++) {
			if (i < A)
				answer[order[i]] = 1;
			else if (i < A + B)
				answer[order[i]] = 2;
			else
				answer[order[i]] = 3;
		}
		return answer;
	}

	dfs0(0);	// keep only the DFS-tree edges (stored in adj)
	parent[0] = -1;
	dfs(0);

	// Group with label k + 1 has size sz[k].
	const int sz[3] = {a, b, c};
	int inLabel = 0;	// label placed inside the cut subtree; 0 = not found yet
	int outLabel = 0;	// label placed outside the cut subtree
	pair<set<int>, set<int>> ans;
	for (int i = 1; i < N && inLabel == 0; i++) {	// try cutting the edge above i
		for (int k = 0; k < 3; k++) {
			if (subtr[i] < sz[k])
				continue;
			// The two other groups, in increasing label order. Put the smaller
			// one outside the subtree (the lower label wins a tie).
			int o1 = (k == 0) ? 1 : 0;
			int o2 = (k == 2) ? 1 : 2;
			int o = (sz[o2] < sz[o1]) ? o2 : o1;
			if (N - subtr[i] >= sz[o]) {
				ans = solve(i, sz[k], sz[o]);
				inLabel = k + 1;
				outLabel = o + 1;
				break;
			}
		}
	}
	if (inLabel == 0)
		return vi(N, 0);	// no split found

	// Labels are 1, 2 and 3, so the unused one is 6 minus the other two.
	const int restLabel = 6 - inLabel - outLabel;
	vi answers(N, restLabel);
	for (int v : ans.fi)
		answers[v] = inLabel;
	for (int v : ans.se)
		answers[v] = outLabel;
	return answers;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...