// Submission #596198
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 596198 | - | AriaH | Split the Attractions (IOI19_split) | C++17 | 40 / 100 | 471 ms | 78844 KiB
#include "split.h"
#pragma GCC optimize("O3")

#include <bits/stdc++.h>

using namespace std;

// Short aliases for the integer / pair types used in this file.
typedef long long ll;
typedef long double ld;
typedef pair < int, int > pii;
typedef pair < ll, ll > pll;

// Shorthand macros. Note that SZ(x) casts the container size to int, and
// that `endl` is redefined to a plain newline string (no flush).
#define F first
#define S second
#define all(x) x.begin(), x.end()
#define SZ(x) (int)x.size()
#define Mp make_pair
#define endl "\n"
#define fast_io ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);

// N bounds the size of every per-vertex global array below.
const int N = 1e6 + 10;
const int LOG = 20;
const ll mod = 1e9 + 7;
const ll inf = 8e18;

// Answer vector returned by find_split; entry v - 1 holds the label of vertex v.
vector < int > ret;

// Adjacency sets, indexed by 1-based vertex id. pre() erases non-tree edges
// from these sets, so after it runs they describe a spanning tree.
set < int > G[N];

// n, m    : vertex and edge counts of the input graph.
// a, b, c : the three requested group sizes.
// par, sz : disjoint-set parent / component size (filled by solve(), used by get()/unify()).
// mark    : visited flags for pre().
// sub     : subtree sizes computed by pre().
int n, m, a, b, c, par[N], sz[N], Can, mark[N], sub[N];

// Edges that pre() removed from G (i.e. edges outside the DFS tree).
vector < pii > Del;

int pre(int v, int P)
{
	int ret = v;
	mark[v] = 1;
	sub[v] = 1;
	vector < int > del;
	for(auto u : G[v])
	{
		if(u == P) continue;
		if(mark[u])
		{
			del.push_back(u);
			continue;
		}
		int nxt = pre(u, v);
		if(sub[u] * 2 > n) ret = nxt;
		sub[v] += sub[u];
	}
	for(auto x : del)
	{
		if(G[v].find(x) != G[v].end())
		{
			Del.push_back(Mp(x, v));
			G[v].erase(x);
		}
	}
	return ret;
}

// Walks the subtree of v (entered from P) and attaches every vertex in it to
// the disjoint-set representative jad, counting the vertices in sz[jad].
void solve(int v, int P, int jad)
{
	++sz[jad];
	par[v] = jad;
	for(int child : G[v])
	{
		if(child != P)
		{
			solve(child, v, jad);
		}
	}
}

// (current degree in G, vertex) pairs collected by dfs(); the smallest entry
// is the vertex of minimum degree.
set < pii > st;

// Labels every vertex reachable from v (without stepping back to P) with col
// and registers it in st keyed by its current degree in G.
void dfs(int v, int P, int col)
{
	ret[v - 1] = col;
	st.insert(Mp(SZ(G[v]), v));
	for(int nxt : G[v])
	{
		if(nxt != P)
		{
			dfs(nxt, v, col);
		}
	}
}

// Maps a group size to its label: 1 if it equals a, otherwise 2 if it equals
// b, otherwise 3.
int calc(int x)
{
	return x == a ? 1 : (x == b ? 2 : 3);
}

// Disjoint-set find: returns the representative of x and re-points every
// vertex on the walked path directly at that representative.
int get(int x)
{
	// First pass: locate the root.
	int root = x;
	while(par[root] != root)
	{
		root = par[root];
	}
	// Second pass: compress the path.
	while(par[x] != root)
	{
		int up = par[x];
		par[x] = root;
		x = up;
	}
	return root;
}

// Disjoint-set union: merges the component of u into the component of v.
// Returns 1 if two different components were merged, 0 if they already
// shared a representative.
int unify(int v, int u)
{
	const int rootV = get(v);
	const int rootU = get(u);
	if(rootV == rootU)
	{
		return 0;
	}
	par[rootU] = rootV;
	sz[rootV] += sz[rootU];
	return 1;
}

// Splits the vertices of a graph into three groups of sizes _a, _b and _c so
// that the groups of the smallest and the middle size are each connected.
//
// Parameters:
//   _n          number of vertices (0-based ids in p/q).
//   _a, _b, _c  requested sizes of groups 1, 2 and 3.
//   p, q        endpoints of the edges; edge i joins p[i] and q[i].
// Returns a vector of length _n whose entry v is the group (1..3) of vertex v,
// or an all-zero vector when no suitable split is found.
//
// Approach:
//   1. pre() reduces G to a DFS spanning tree (logging the removed edges in
//      Del) and yields a vertex `cen` none of whose subtrees, after re-rooting
//      at it, holds more than half of the vertices.
//   2. Every subtree hanging off cen becomes one disjoint-set component.
//      Components are merged along removed edges that do not touch cen until
//      one of them has at least Mn (the smallest requested size) vertices.
//   3. Mn vertices are grown by BFS inside that component without using cen,
//      then Mn2 (the middle size) vertices are grown by BFS from cen over the
//      still unlabeled vertices. Everything left gets the third label.
//
// Fixes over the previous version: the merged component `id` is now actually
// used (the old code kept colouring from B over tree edges only, so a merged
// component produced an undersized group), edges incident to cen are no
// longer fed to unify() (par[cen] is never initialised, so they glued
// subtrees together through the centroid), and ret is reset with assign().
//
// NOTE: the file-level state (G, Del, par, sz, ...) is not cleared here, so
// the function is only meant to be called once per process.
vector < int > find_split(int _n, int _a, int _b, int _c, vector < int > p, vector < int > q)
{
	n = _n;
	a = _a;
	b = _b;
	c = _c;
	m = SZ(p);
	// Full adjacency list, kept separately because pre() removes the non-tree
	// edges from G. Vertices are shifted to 1-based ids.
	vector < vector < int > > adj(n + 1);
	for(int i = 0; i < m; i ++)
	{
		p[i] ++, q[i] ++;
		G[p[i]].insert(q[i]);
		G[q[i]].insert(p[i]);
		adj[p[i]].push_back(q[i]);
		adj[q[i]].push_back(p[i]);
	}
	// First pass finds the centroid candidate and turns G into a tree; the
	// second pass recomputes sub[] with the tree rooted at cen.
	int cen = pre(1, 0);
	memset(mark, 0, sizeof mark);
	pre(cen, 0);
	for(auto u : G[cen])
	{
		solve(u, cen, u);
	}
	// B = root of the largest subtree hanging off cen (sub[0] is 0).
	int B = 0;
	for(auto u : G[cen])
	{
		if(sub[u] > sub[B]) B = u;
	}
	int Mn = min({a, b, c});
	ret.assign(n, 0);
	// id always names the representative of the largest component so far.
	int id = B;
	for(auto [x, y] : Del)
	{
		if(sz[id] >= Mn) break;
		// An edge through cen does not connect two subtrees once cen is set
		// aside for the other group.
		if(x == cen || y == cen) continue;
		if(unify(x, y))
		{
			if(sz[get(x)] > sz[id])
			{
				id = get(x);
			}
		}
	}
	if(sz[id] < Mn)
	{
		return ret;
	}
	int col = calc(Mn);
	int Mn2 = a + b + c - Mn - max({a, b, c});
	// Label of the middle-sized group; must differ from col when sizes tie.
	int col2 = -1;
	if(a == Mn2 && col != 1)
	{
		col2 = 1;
	}
	else if(b == Mn2 && col != 2)
	{
		col2 = 2;
	}
	else if(c == Mn2 && col != 3)
	{
		col2 = 3;
	}
	// Labels up to `want` unlabeled vertices with `color` in BFS order from
	// `start`, so the labeled set is connected. When comp is non-zero the
	// search stays inside disjoint-set component comp and never enters cen.
	auto grow = [&](int start, int want, int color, int comp)
	{
		vector < int > order(1, start);
		ret[start - 1] = color;
		for(int head = 0; head < SZ(order) && SZ(order) < want; head ++)
		{
			for(auto u : adj[order[head]])
			{
				if(SZ(order) >= want) break;
				if(ret[u - 1] != 0) continue;
				if(comp != 0 && (u == cen || get(u) != comp)) continue;
				ret[u - 1] = color;
				order.push_back(u);
			}
		}
	};
	grow(id, Mn, col, id);
	grow(cen, Mn2, col2, 0);
	// Whatever is still unlabeled forms the third group.
	for(int i = 0; i < n; i ++)
	{
		if(ret[i] == 0)
		{
			ret[i] = 1 + 2 + 3 - col - col2;
		}
	}
	return ret;
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...