제출 #596169

#제출 시각아이디문제언어결과실행 시간메모리
596169AriaHSplit the Attractions (IOI19_split)C++17
22 / 100
500 ms80144 KiB
#include "split.h"
#pragma GCC optimize("O3")

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;
typedef long double ld;
typedef pair < int, int > pii;
typedef pair < ll, ll > pll;

#define F first
#define S second
#define all(x) x.begin(), x.end()
#define SZ(x) (int)x.size()
#define Mp make_pair
#define endl "\n"
#define fast_io ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);

const int N = 1e6 + 10;
const int LOG = 20;
const ll mod = 1e9 + 7;
const ll inf = 8e18;

vector < int > ret;

set < int > G[N];

int n, m, a, b, c, Can, mark[N], sub[N];

void pre(int v, int P)
{
	mark[v] = 1;
	sub[v] = 1;
	vector < int > del;
	for(auto u : G[v])
	{
		if(u == P) continue;
		if(mark[u])
		{
			del.push_back(u);
			continue;
		}
		pre(u, v);
		sub[v] += sub[u];
	}
	if(!Can) return;
	for(auto x : del)
	{
		if(G[v].find(x) != G[v].end())
		{
			G[v].erase(x);
		}
	}
}

int find(int v, int P)
{
	mark[v] = 1;
	for(auto u : G[v])
	{
		if(u == P) continue;
		if(mark[u]) continue;
		if(sub[u] * 2 > n) return find(u, v);
	}
	return v;
}

set < pii > st;

void dfs(int v, int P, int col)
{
	st.insert(Mp(SZ(G[v]), v));
	ret[v - 1] = col;
	for(auto u : G[v])
	{
		if(u == P) continue;
		dfs(u, v, col);
	}
}

int calc(int x)
{
	if(x == a) return 1;
	if(x == b) return 2;
	return 3;
}

ll rand(ll l, ll r)
{
	return rand() % (r - l + 1) + l;
}

vector < int > find_split(int _n, int _a, int _b, int _c, vector < int > p, vector < int > q)
{
	n = _n;
	a = _a;
	b = _b;
	c = _c;
	m = SZ(p);
	for(int i = 0; i < m; i ++)
	{
		p[i] ++, q[i] ++;
		G[p[i]].insert(q[i]);
		G[q[i]].insert(p[i]);
	}
	int St = rand(1, n);
	Can = 0;
	pre(St, 0);
	memset(mark, 0, sizeof mark);
	int cen = find(St, 0);
	memset(mark, 0, sizeof mark);
	Can = 1;
	pre(cen, 0);
	int B = 0;
	for(auto u : G[cen])
	{
		if(sub[u] > sub[B]) B = u;
	}
	ret.resize(n, 0);
	if(min({a, b, c}) > sub[B])
	{
		return ret;
	}
	G[B].erase(cen);
	G[cen].erase(B);
	int Mn = min({a, b, c});
	int col = calc(Mn);
	dfs(B, cen, col);
	while(SZ(st) > Mn)
	{
		int node = st.begin()->S;
		st.erase(st.begin());
		ret[node - 1] = 0;
		for(auto u : G[node])
		{
			st.erase(Mp(SZ(G[u]), u));
			G[u].erase(node);
			st.insert(Mp(SZ(G[u]), u));
		}
		G[node].clear();
	}
	int Mn2 = a + b + c - Mn - max({a, b, c});
	int col2 = -1;
	if(a == Mn2 && col != 1)
	{
		col2 = 1;
	}
	else if(b == Mn2 && col != 2)
	{
		col2 = 2;
	}
	else if(c == Mn2 && col != 3)
	{
		col2 = 3;
	}
	st.clear();
	dfs(cen, B, col2);
	while(SZ(st) > Mn2)
	{
		int node = st.begin()->S;
		st.erase(st.begin());
		ret[node - 1] = 0;
		for(auto u : G[node])
		{
			st.erase(Mp(SZ(G[u]), u));
			G[u].erase(node);
			st.insert(Mp(SZ(G[u]), u));
		}
		G[node].clear();
	}
	for(int i = 0; i < n; i ++)
	{
		if(ret[i] == 0)
		{
			ret[i] = 1 + 2 + 3 - col - col2;
		}
	}
	return ret;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...