Submission #144864

# | Time | Username | Problem | Language | Result | Execution time | Memory
144864 | eriksuenderhauf | Split the Attractions (IOI19_split) | C++14
100 / 100
521 ms | 67944 KiB
//#pragma GCC optimize("O3")
#include <bits/stdc++.h>
#include "split.h"
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <ext/rope>
// ---- Shorthand macros --------------------------------------------------
// Purely textual abbreviations; several of them are not referenced by the
// code visible in this file.

// Fill the whole array `a` with byte value `v` (sizeof is taken on `a` itself,
// so `a` must be a real array, not a pointer).
#define mem(a,v) memset((a), (v), sizeof (a))
// Print a newline.
#define enl printf("\n")
// Print a "Case #t: " prefix.
// NOTE(review): this reuses the keyword `case` as a function-like macro name;
// a label written as `case (x):` after this point would be expanded too.
#define case(t) printf("Case #%d: ", (t))
// Read one int into `n`.
#define ni(n) scanf("%d", &(n))
// Read one 64-bit integer into `n`.
// NOTE(review): `%I64d` is a Microsoft-runtime length modifier; `%lld` is the
// portable spelling - confirm before relying on nl()/prl().
#define nl(n) scanf("%I64d", &(n))
// Read `n` ints / 64-bit integers into a[0..n-1]; both declare their own `i`.
#define nai(a, n) for (int i = 0; i < (n); i++) ni(a[i])
#define nal(a, n) for (int i = 0; i < (n); i++) nl(a[i])
// Print an int / 64-bit integer followed by a newline.
#define pri(n) printf("%d\n", (n))
#define prl(n) printf("%I64d\n", (n))
// Pair type abbreviations.
#define pii pair<int, int>
#define pil pair<int, long long>
#define pll pair<long long, long long>
// Vector type abbreviations (vii/vil/vll are vectors of the pairs above).
#define vii vector<pii>
#define vil vector<pil>
#define vll vector<pll>
#define vi vector<int>
#define vl vector<long long>
// Member / function name abbreviations.
#define pb push_back
#define mp make_pair
#define fi first
#define se second
// Bring std and the GNU policy-based data structure names into global scope.
using namespace std;
using namespace __gnu_pbds;
typedef long long ll;
// Aliases for a pb_ds hash table and an order-statistics tree; neither alias
// is referenced by the code visible in this file.
typedef cc_hash_table<int,int,hash<int>> ht;
typedef tree<int,null_type,less<int>,rb_tree_tag,tree_order_statistics_node_update> oset;
// Generic template constants; pi, MOD, INF and eps are not referenced by the
// functions visible in this file.
const double pi = acos(-1);
const int MOD = 1e9 + 7;
const int INF = 1e9 + 7;
// Capacity of the per-vertex global arrays below.
const int MAXN = 1e6 + 5;
const double eps = 1e-9;
// par   - union-find parents used by qry()/join() while find_split() picks tree edges.
// sz    - subtree sizes written by getSz().
// heavy - per vertex, the largest piece left when that vertex is removed from the
//         tree, or -1 when that piece exceeds (n+1)/2; written by getSz().
int par[MAXN], sz[MAXN], heavy[MAXN];
// par2  - second union-find, re-initialised by check() and used by dfs()/flood().
int par2[MAXN];
// adj   - adjacency lists of the edges find_split() accepted into the spanning tree.
// adj2  - adjacency lists of every input edge.
vi adj[MAXN], adj2[MAXN];

// Returns the representative of x's set in the `par` forest and re-points
// every vertex visited on the way directly at that representative.
int qry(int x) {
	// First pass: walk up to the representative.
	int root = x;
	while (par[root] != root)
		root = par[root];
	// Second pass: compress the path that was just walked.
	while (par[x] != root) {
		const int next = par[x];
		par[x] = root;
		x = next;
	}
	return root;
}
// Merges the set containing u into the set containing v: the representative
// of u's set becomes a child of the representative of v's set.
void join(int u, int v) {
	const int target = qry(v);
	const int source = qry(u);
	par[source] = target;
}

// Returns the 1-based positions of (x, y, z) listed in non-decreasing order of
// value. Equal values keep their left-to-right order, which is exactly the
// lexicographically smallest ordering permutation.
vi f(int x, int y, int z) {
	const int vals[3] = {x, y, z};
	vi order = {1, 2, 3};
	stable_sort(order.begin(), order.end(), [&vals](int lhs, int rhs) {
		return vals[lhs - 1] < vals[rhs - 1];
	});
	return order;
}

void getSz(int u, int p, int n) {
	sz[u] = 1;
	for (int v: adj[u]) {
		if (v == p)
			continue;
		getSz(v, u, n);
		sz[u] += sz[v];
		if (sz[v] > heavy[u])
			heavy[u] = sz[v];
	}
	if (n - sz[u] > heavy[u])
		heavy[u] = n - sz[u];
	if ((n+1) / 2 < heavy[u])
		heavy[u] = -1;
}

// Input edges that find_split() did not add to `adj` because their endpoints
// were already connected; check() walks this list to merge subtree components.
vii inact;

// Same as qry(), but on the `par2` forest: returns the representative of x's
// set and points every vertex on the walked path straight at it.
int qry2(int x) {
	int rep = x;
	while (par2[rep] != rep)
		rep = par2[rep];
	for (int cur = x; par2[cur] != rep; ) {
		const int up = par2[cur];
		par2[cur] = rep;
		cur = up;
	}
	return rep;
}
// Merges u's set into v's set in the `par2` forest; v's representative stays
// the representative of the merged set.
void join2(int u, int v) {
	const int keep = qry2(v);
	const int absorbed = qry2(u);
	par2[absorbed] = keep;
}

// Walks the `adj` tree from u (never stepping back to p), puts every visited
// vertex into r's set in the `par2` union-find, and returns how many vertices
// were visited.
int dfs(int u, int p, int r) {
	join2(u, r);
	int visited = 1;
	for (int nxt : adj[u]) {
		if (nxt == p)
			continue;
		visited += dfs(nxt, u, r);
	}
	return visited;
}

// Depth-first colouring over the full graph `adj2`.
//
// u   - vertex to colour next
// col - label written into ret[]
// cnt - remaining number of vertices to colour; decremented per coloured vertex
// ret - per-vertex labels; a non-zero entry means "already coloured"
// fl  - when false, only neighbours in the same `par2` set as u are followed;
//       when true, every neighbour is followed
void flood(int u, int col, int& cnt, vi& ret, bool fl = false) {
	if (cnt == 0 || ret[u] != 0)
		return;
	ret[u] = col;
	--cnt;
	for (int nb : adj2[u]) {
		if (qry2(nb) != qry2(u) && !fl)
			continue;
		flood(nb, col, cnt, ret, fl);
	}
}

bool check(int root, int n, vi& ret, vi& col, vi& tmp) {
	for (int i = 0; i < n; i++) par2[i] = i;
	set<pii> childsz;
	set<pii> rev;
	for (int v: adj[root]) {
		int tmpsz = dfs(v, root, v);
		childsz.insert(mp(tmpsz, v));
		rev.insert(mp(v, tmpsz));
	}
	for (int i = 0; i < inact.size() && tmp[0] > (*childsz.rbegin()).fi; i++) {
		int u = qry2(inact[i].fi), v = qry2(inact[i].se);
		if (u == v || u == root || v == root)
			continue;
		auto it1 = *rev.lower_bound(mp(u,-1));
		auto it2 = *rev.lower_bound(mp(v,-1));
		childsz.erase(mp(it1.se,it1.fi));
		childsz.erase(mp(it2.se,it2.fi));
		int nsz = it1.se+it2.se;
		rev.erase(it1);
		rev.erase(it2);
		par2[u] = v;
		childsz.insert(mp(nsz,v));
		rev.insert(mp(v,nsz));
	}
	if (tmp[0] > (*childsz.rbegin()).fi)
		return false;
	int cnt = tmp[0], curcol = col[0];
	flood((*childsz.rbegin()).se, curcol, cnt, ret);
	cnt = tmp[1], curcol = col[1];
	flood(root, curcol, cnt, ret, 1);
	for (int i = 0; i < n; i++)
		if (!ret[i])
			ret[i] = col[2];
	return true;
}

// Entry point required by split.h.
//
// n       - number of vertices (0..n-1)
// a, b, c - requested sizes of the groups labelled 1, 2 and 3
// p, q    - endpoints of the edges: edge i joins p[i] and q[i]
//
// Returns a vector of n labels. When no candidate vertex passes check() the
// vector is returned with all entries still 0.
vi find_split(int n, int a, int b, int c, vi p, vi q) {
	// All working state lives in file-level globals; reset the part this call
	// uses so that a repeated call does not see edges or heavy[] values left
	// behind by an earlier one.
	inact.clear();
	for (int i = 0; i < n; i++) {
		par[i] = i;
		heavy[i] = 0;
		adj[i].clear();
		adj2[i].clear();
	}
	const int m = static_cast<int>(p.size());
	vi col = f(a,b,c);
	vi tmp = {a,b,c}; sort(tmp.begin(), tmp.end());
	vi ret(n);
	// Every edge goes into adj2; an edge joining two different union-find sets
	// also goes into the tree `adj`, all others are remembered in `inact`.
	for (int i = 0; i < m; i++) {
		adj2[p[i]].pb(q[i]);
		adj2[q[i]].pb(p[i]);
		if (qry(p[i]) == qry(q[i]))
			inact.pb(mp(p[i],q[i]));
		else {
			join(p[i],q[i]);
			adj[p[i]].pb(q[i]);
			adj[q[i]].pb(p[i]);
		}
	}
	getSz(0, -1, n);
	// Candidates are the vertices whose heavy[] value is maximal (getSz() has
	// already replaced values above (n+1)/2 by -1).
	vi cent;
	int cur = -1;
	for (int i = 0; i < n; i++)
		if (heavy[i] > cur) {
			cur = heavy[i];
			cent = {i};
		} else if (heavy[i] == cur)
			cent.pb(i);
	// Named `root` so it does not shadow the parameter `c`.
	for (int root: cent)
		if (check(root, n, ret, col, tmp))
			return ret;
	return ret;
}

Compilation message (stderr)

split.cpp: In function 'bool check(int, int, std::vector<int>&, std::vector<int>&, std::vector<int>&)':
split.cpp:103:20: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
  for (int i = 0; i < inact.size() && tmp[0] > (*childsz.rbegin()).fi; i++) {
                  ~~^~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...