Submission #295061

#TimeUsernameProblemLanguageResultExecution timeMemory
295061Atill83Split the Attractions (IOI19_split)C++14
18 / 100
137 ms16120 KiB
#include "split.h"
#include <bits/stdc++.h>
using namespace std;
const int MAXN = (int) 3e5 + 5;
int N, m;
vector<int> res;
vector<int> adj[MAXN];
bool visited[MAXN];
int kal[3];
int sz[MAXN];
void dfs1(int v){
	visited[v] = 1;
	if(kal[1] == 0) return;
	res[v] = 2;
	kal[1]--;
	if(kal[1] == 0) return;
	for(int j: adj[v]){
		if(!visited[j]) 
			dfs1(j);
	}
}
int node = -1, st = -1;
void dfs2(int v, int par, int k1, int k2){
	sz[v] = 1;
	for(int j: adj[v]){
		if(j != par){
			dfs2(j, v, k1, k2);
			sz[v] += sz[j];
		}
	}
	int up = N - sz[v];

	if(up >= kal[k1] && sz[v] >= kal[k2]){
		node = v;
		st = 0;
	}else if(up >= kal[k2] && sz[v] >= kal[k1]){
		node = v;
		st = 1;
	}
}

void dfs4(int v, int par, int k){
	if(kal[k] == 0) return;
	res[v] = k + 1;
	kal[k]--;
	if(kal[k] == 0) return;
	for(int j: adj[v]){
		if(j != par) dfs4(j, v, k);
	}
}


void dfs3(int v, int par, int k1, int k2){
	if(node == v){
		if(st == 0){
			dfs4(par, v, k1);
			res[v] = k2 + 1;
			kal[k2]--;
		}else{ 
			dfs4(par, v, k2);
			res[v] = k1 + 1;
			kal[k1]--;
		}
		for(int j: adj[v]){
			if(st)
				dfs4(j, v, k1);
			else dfs4(j, v, k2);
		}
		return;
	}
	for(int j: adj[v]){
		if(j != par){
			dfs3(j, v, k1, k2);
		}
	}
}
int sim = 0;
void dfs5(int v){
	if(kal[sim] == 0) sim++;
	res[v] = sim + 1;
	kal[sim]--;
	for(int j: adj[v]){
		if(!res[j]){
			dfs5(j);
		}
	}
}



vector<int> find_split(int n, int a, int b, int c, vector<int> p, vector<int> q) {
	N = n;
	res.resize(n, 0);
	m = p.size();
	int mx = 0;
	for(int i = 0; i < m; i++){
		adj[p[i]].push_back(q[i]);
		adj[q[i]].push_back(p[i]);
	}
	for(int i = 0; i < n; i++) mx = max(mx, (int)adj[i].size());
	kal[0] = a;
	kal[1] = b;
	kal[2] = c;
	if(a == 1){
		dfs1(0);
		bool done = 0;
		for(int i = 0; i < n; i++){
			if(res[i]) continue;
			if(done) res[i] = 3;
			else{
				res[i] = 1;
				done = 1;
			}
		}
	}else if(mx <= 2){
		dfs5(0);
	}else if(m == n - 1){
		vector<int> ali = {0, 1, 2};
		sort(ali.begin(), ali.end(), [](int a, int b){
			return kal[a] < kal[b];
		});
		dfs2(0, -1, ali[0], ali[1]);
		if(node != -1){
			dfs3(0, -1, ali[0], ali[1]);
			for(int i = 0; i < n; i++) 
				if(res[i] == 0) 
					res[i] = ali[2] + 1;
		}
	}

	return res;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...