제출 #722529

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
722529 |  | ymm | Split the Attractions (IOI19_split) | C++17 | 18 / 100 | 218 ms | 56772 KiB
#include "split.h"
#include <bits/stdc++.h>
// Half-open counting loops over [lo, hi) with a long long counter:
// Loop walks upward from lo, LoopR walks downward from hi-1.
#define Loop(var, lo, hi) for (ll var = (lo); var < (ll)(hi); ++var)
#define LoopR(var, lo, hi) for (ll var = (hi)-1; var >= (ll)(lo); --var)
using ll = long long;
using pii = std::pair<int, int>;
using pll = std::pair<ll, ll>;
using namespace std;

// Capacity of every per-vertex array (assumes n stays below this bound — TODO confirm against problem limits).
const int N = 100'010;

// A[v]: adjacency list of the input graph (filled by find_split).
std::vector<int> A[N];
// C[v]: children of v in the DFS tree (filled by dfs0).
std::vector<int> C[N];

// Per-vertex DFS data written by dfs0:
//   sz     - size of the DFS subtree rooted at the vertex
//   mn     - lowest height reachable from the subtree (low value)
//   height - depth of the vertex in the DFS tree
int sz[N];
int mn[N];
int height[N];

// Visited flags shared by dfs0 and dfs3 (solve re-initialises them before each dfs3).
bool vis[N];
// Depth-first search from v (at depth h) over the input graph A that builds
// the DFS tree and its per-vertex data:
//   vis[v]    - marked visited
//   height[v] - depth h
//   sz[v]     - number of vertices in v's DFS subtree
//   C[v]      - tree children, in adjacency order
//   mn[v]     - smallest height of an already-visited vertex adjacent to
//               some vertex of v's subtree (or h itself). Because the parent
//               is visited and adjacent, mn[v] <= h-1 for every non-root v.
// NOTE(review): recursion depth equals the tree depth (up to n).
void dfs0(int v, int h)
{
	vis[v] = 1;
	mn[v] = height[v] = h;
	sz[v] = 1;
	for (int u : A[v]) {
		if (vis[u]) {
			// Non-tree (or parent) edge: may lower v's reachable height.
			mn[v] = min(mn[v], height[u]);
			continue;
		}
		C[v].push_back(u);
		dfs0(u, h+1);
		sz[v] += sz[u];
		// Fold the child's low value into ours. The previous code used
		// min(v, mn[u]), which compared against the vertex id and threw away
		// any back edges already recorded in mn[v].
		mn[v] = min(mn[v], mn[u]);
	}
}

// is_in[u]: u's DFS subtree is currently counted as "detachable" by dfs1.
bool is_in[N];
// Try to mark v's subtree as detachable relative to the current root rt.
// It is rejected (returns 0) when v is already marked or when
// mn[v] >= height[rt], i.e. per dfs0's low value the subtree does not reach
// strictly above rt. On success, marked direct children of v are unmarked
// (they are already covered by sz[v]). Returns the net number of newly counted
// vertices.
// NOTE(review): only direct children are checked; marked deeper descendants
// stay marked, so their size may be counted twice.
int add(int v, int rt) {
	const bool reaches_above = mn[v] < height[rt];
	if (is_in[v] || !reaches_above)
		return 0;
	int gained = sz[v];
	for (const int child : C[v]) {
		if (!is_in[child])
			continue;
		is_in[child] = 0;
		gained -= sz[child];
	}
	is_in[v] = 1;
	return gained;
}
// Unmark v's subtree as detachable. Returns sz[v] if it was marked (the amount
// the caller must subtract from its running total), otherwise 0.
// rt is accepted for symmetry with add() but not used.
int rem(int v, [[maybe_unused]] int rt) {
	const bool was_marked = is_in[v];
	is_in[v] = 0;
	return was_marked ? sz[v] : 0;
}

// Number of vertices in the current instance (set by find_split).
int n;
// Edge count; not read or written anywhere in this file.
int m;

void merge(set<pii> &a, set<pii> &b)
{
	if (a.size() < b.size())
		a.swap(b);
	for (auto x : b)
		a.insert(x);
	b.clear();
}

int dfs1(int v, int sz1, int sz2, set<pii> &by_mn, set<pii> &by_sz, int &sum)
{
	by_mn = {{mn[v], v}};
	by_sz = {{sz[v], v}};
	sum = 0;
	for (int u : C[v]) {
		set<pii> x, y;
		int z;
		int ret;
		ret = dfs1(u, sz1, sz2, x, y, z);
		if (ret != -1)
			return ret;
		merge(by_mn, x);
		merge(by_sz, y);
		sum += z;
	}
	while (by_sz.size()) {
		int u = by_sz.begin()->second;
		if (sz[v] - sz[u] < sz2)
			break;
		sum += add(u, v);
		by_sz.erase(by_sz.begin());
	}
	while (by_mn.size()) {
		int u = by_mn.begin()->second;
		if (mn[u] > height[v])
			break;
		sum -= rem(u, v);
		by_mn.erase(by_mn.begin());
	}
	if (sz[v] >= sz2 && n - sz[v] + sum >= sz1)
		return v;
	return -1;
}

// Collect into vec the topmost vertices u of v's DFS subtree (v included) with
// sz_rt - sz[u] >= sz_target, i.e. whose subtree could be cut away while
// leaving at least sz_target of the sz_rt vertices. The search does not
// descend below a collected vertex.
void dfs2(int v, int sz_rt, int sz_target, vector<int> &vec)
{
	const int left_behind = sz_rt - sz[v];
	if (left_behind < sz_target) {
		for (const int child : C[v])
			dfs2(child, sz_rt, sz_target, vec);
		return;
	}
	vec.push_back(v);
}

// Paint every vertex of the DFS-tree subtree rooted at v with colour c.
// Uses an explicit stack; only the constant c is written, so visiting order
// does not affect the result.
void dfs_col(int v, vector<int> &col, int c)
{
	vector<int> pending{v};
	while (!pending.empty()) {
		const int cur = pending.back();
		pending.pop_back();
		col[cur] = c;
		for (const int child : C[cur])
			pending.push_back(child);
	}
}

// Flood-fill the input graph from v, painting vertices with colour c until the
// budget `remaining` reaches zero. Each painted vertex is marked in vis and
// consumes one unit of budget. Neighbours already marked in vis are skipped.
// The start vertex itself is painted without a vis check.
void dfs3(int v, vector<int> &col, int c, int &remaining)
{
	if (remaining == 0)
		return;
	vis[v] = 1;
	col[v] = c;
	remaining -= 1;
	for (const int nb : A[v]) {
		if (remaining == 0)
			break;
		if (!vis[nb])
			dfs3(nb, col, c, remaining);
	}
}

// Attempt one split orientation:
//   colour 0 (label sa) gets exactly a vertices, grown from vertex 0 outside
//   a chosen DFS subtree;
//   colour 1 (label sb) gets exactly b vertices, grown from the subtree root v;
//   every other vertex gets label sc.
// Returns the label vector of size n, or an empty vector when dfs1 finds no
// suitable subtree root.
vector<int> solve(int a, int b, int sa, int sb, int sc)
{
	// Output holders for dfs1's root-level call; not used afterwards.
	set<pii> x, y;
	int z;
	// Reset the "detachable subtree" marks used by add()/rem() inside dfs1.
	memset(is_in, 0, sizeof(is_in));
	int v = dfs1(0, a, b, x, y, z);
	if (v == -1)
		return {};
	// Candidate subtrees below v, each of which could individually be cut
	// away while leaving at least b vertices in v's subtree.
	// NOTE(review): dfs2 does not consult is_in, so a candidate is not checked
	// for an edge reaching above v; the moved part may not connect to vertex 0.
	vector<int> vec;
	dfs2(v, sz[v], b, vec);
	sort(vec.begin(), vec.end(), [](int i, int j) {
		return sz[i] < sz[j];
	});
	// Provisional colours: -1 = outside part, -2 = v's subtree.
	vector<int> col(n);
	dfs_col(0, col, -1);
	dfs_col(v, col, -2);
	int sza = n - sz[v], szb = sz[v];
	// Greedily move the largest candidates to the outside part until it has
	// at least a vertices.
	while (vec.size() && sza < a) {
		int u = vec.back();
		vec.pop_back();
		sza += sz[u];
		szb -= sz[u];
		dfs_col(u, col, -1);
	}
	// NOTE(review): candidates were only checked one at a time against b, so
	// moving several of them can leave szb < b and trip the second assert.
	assert(sza >= a);
	assert(szb >= b);
	// Paint exactly a vertices with colour 0, flooding from vertex 0 through
	// the -1 region only (everything else is pre-marked visited).
	Loop (i,0,n)
		vis[i] = col[i] != -1;
	dfs3(0, col, 0, a);
	assert(a == 0);
	// Paint exactly b vertices with colour 1, flooding from v through the -2
	// region only.
	Loop (i,0,n)
		vis[i] = col[i] != -2;
	dfs3(v, col, 1, b);
	assert(b == 0);
	// Every vertex still carrying a provisional colour joins the third part.
	Loop (i,0,n) {
		if (col[i] < 0)
			col[i] = 2;
	}
	// Translate colour indices 0/1/2 into the requested labels.
	Loop (i,0,n)
		col[i] = vector<int>{sa, sb, sc}[col[i]];
	return col;
}

vector<int> find_split(int _n, int a, int b, int c, vector<int> p, vector<int> q)
{
	n = _n;
	Loop (i,0,p.size()) {
		int v = p[i], u = q[i];
		A[v].push_back(u);
		A[u].push_back(v);
	}
	dfs0(0, 0);
	int sa = 1, sb = 2, sc = 3;
	if (a > b) {
		swap(a, b);
		swap(sa, sb);
	}
	if (b > c) {
		swap(b, c);
		swap(sb, sc);
	}
	if (a > b) {
		swap(a, b);
		swap(sa, sb);
	}
	vector<int> ans;
	if ((ans = solve(a, b, sa, sb, sc)).size())
		return ans;
	if ((ans = solve(b, a, sb, sa, sc)).size())
		return ans;
	return vector<int>(n, 0);
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...