제출 #1047658

| # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
|---|---|---|---|---|---|---|---|
| 1047658 | | mychecksedad | Split the Attractions (IOI19_split) | C++17 | 100 / 100 | 163 ms | 57880 KiB |
#include "split.h"
#include<bits/stdc++.h>
using namespace std;
// Competitive-programming shorthands used throughout the file.
#define ll long long int
#define pb push_back
#define ff first
#define all(x) x.begin(), x.end()
#define ss second
// Capacity of the static per-vertex arrays below (vertex ids must stay below N).
const int N = 5e5;

// Number of vertices; set at the start of f().
int n;
// Adjacency lists of the input graph (built in f()).
vector<int> g[N];
// Label of each vertex (0 = not coloured yet). Resized to n+1 in f(); the extra
// slot belongs to the sentinel vertex n that serves as the DFS root's parent.
vector<int> res;
// Visited / blocked flags used by fill() while colouring.
vector<bool> vis;
// Marks DFS-tree edges (set to 1 in both directions by dfs()); fill() only walks
// along edges marked here.
map<pair<int, int>, bool> edge;
// DFS-tree parent of each vertex; the root's parent is the sentinel n.
int par[N];

// Colours up to `colnum` vertices with label `col`, breadth-first from `v`,
// moving only along DFS-tree edges (those marked in `edge`) to vertices not yet
// marked in `vis`.
//   v      - start vertex (coloured first when the quota is positive)
//   par    - vertex temporarily marked visited so the search cannot pass through
//            it (callers pass v's tree parent, confining the search to v's
//            subtree); shadows the global par[] array inside this function
//   col    - label written into res[]
//   colnum - remaining quota; decremented once per coloured vertex (by reference,
//            so the caller's counter is consumed)
// Every coloured vertex is left marked in vis; vis[par] gets its previous value back.
void fill(int v, int par, int col, int &colnum){
	if(colnum <= 0) return;

	// Remember whether `par` was already blocked, so we do not wipe a flag some
	// earlier step set (e.g. when par lies in an already-coloured region).
	const bool parWasVisited = vis[par];
	vis[par] = 1;

	vis[v] = 1;
	res[v] = col;
	--colnum;

	// Read-only lookup: map::operator[] would insert a `false` entry for every
	// non-tree edge we probe.
	auto isTreeEdge = [](int x, int y){
		auto it = edge.find({x, y});
		return it != edge.end() && it->second;
	};

	queue<int> q;
	q.push(v);
	while(colnum > 0 && !q.empty()){
		const int cur = q.front();
		q.pop();
		for(int u: g[cur]){
			if(vis[u] || !isTreeEdge(u, cur)) continue;
			vis[u] = 1;
			res[u] = col;
			q.push(u);
			if(--colnum == 0) break;
		}
	}

	vis[par] = parWasVisited;
}


// (required size, label) for the three classes; sorted ascending in f(), so
// arr[0] is the smallest target and arr[2] the largest. fill() consumes the
// size fields as quotas.
vector<pair<int, int>> arr;
// Visited flags for dfs().
vector<bool> viss;
// Deepest vertex (by dep) whose subtree size reaches arr[0].ff; -1 until dfs() finds one.
int bad = -1;
// Per-vertex subtree size, DFS entry/exit timestamps (shared counter `timer`),
// and depth. The sentinel parent n keeps dep[n] == 0, so the root gets depth 1.
int sz[N], tin[N], tout[N], timer = 0, dep[N];
void dfs(int v, int p){
	sz[v] = 1;
	par[v] = p;
	dep[v] = dep[p] + 1;
	tin[v] = timer++;
	viss[v] = 1;
	for(int u: g[v]){
		if(u != p && !viss[u]){
			edge[{u,v}] = edge[{v,u}]=1;
			dfs(u, v);
			sz[v] += sz[u];
		}
	}
	if(sz[v] >= arr[0].ff){
		if(bad == -1 || dep[bad] < dep[v]){
			bad = v;
		}
	}
	tout[v] = timer++;
}
// True when u is an ancestor of v in the DFS tree (u == v counts). This holds
// exactly when v's [tin, tout] interval nests inside u's.
bool is_ancestor(int u, int v){
	const bool entersNoLater = tin[v] >= tin[u];
	const bool leavesNoLater = tout[u] >= tout[v];
	return entersNoLater && leavesNoLater;
}
// Result of f(), stored by find_split() before it is returned.
vector<int> ress;
void pre(int v, int col){
	vis[v] = 1;
	for(int u: g[v]){
		if(!vis[u] && res[u] == col){
			dfs(u, col);
		}
	}
}


// Core solver. Vertices are 0..nn-1 and edges are (Y[i], X[i]). a, b and c are
// the required sizes of the classes labelled 1, 2 and 3.
// Returns nn labels in {1,2,3}, or nn zeros when no split was found.
//
// Strategy as implemented:
//  1. Build a DFS spanning tree from vertex 0 (edge order shuffled with rand()).
//  2. Look for a tree edge whose two sides hold at least the smallest target and
//     at least the middle target. If found, BFS-fill both sides; every other
//     vertex gets the largest class's label.
//  3. Otherwise take `bad`, the deepest vertex whose subtree reaches the smallest
//     target. Detach descendant subtrees that have an edge leaving bad's subtree,
//     moving them to the outside part until that part is large enough.
//
// Reads and mutates the file-level globals (g, res, vis, viss, edge, arr, bad,
// timer, sz, par, tin, tout, dep).
vector<int> f(int nn, int a, int b, int c, vector<int> Y, vector<int> X) {
	n = nn;

	// Reset state left behind by a previous call so f can safely run more than once.
	arr.clear();
	edge.clear();
	bad = -1;
	timer = 0;
	for(int i = 0; i < n; ++i) g[i].clear();

	// arr = (required size, label) sorted ascending: arr[0] smallest, arr[2] largest.
	arr.pb({a, 1});
	arr.pb({b, 2});
	arr.pb({c, 3});
	sort(all(arr));

	// Randomise the edge order, which randomises the DFS tree. r is drawn from
	// [0, i), which is Sattolo's variant rather than a uniform Fisher-Yates.
	for(size_t i = 1; i < Y.size(); ++i){
		const size_t r = static_cast<size_t>(rand()) % i;
		swap(Y[i], Y[r]);
		swap(X[i], X[r]);
	}
	for(size_t i = 0; i < Y.size(); ++i){
		g[Y[i]].pb(X[i]);
		g[X[i]].pb(Y[i]);
	}

	const int root = 0;
	// One extra slot for the sentinel vertex n, used as the root's parent.
	res.assign(n + 1, 0);
	vis.assign(n + 1, false);
	viss.assign(n + 1, false);
	dfs(root, n);

	// Step 1: try to cut a single tree edge.
	bool ok = false;
	// Set when some non-root subtree is at least as large as the smallest target.
	bool bigSubtreeSeen = false;
	for(int i = 0; i < n; ++i){
		if(i == root) continue;
		const int inside = sz[i];
		const int outside = sz[root] - sz[i];
		if(inside >= arr[0].ff && outside >= arr[1].ff){
			fill(i, par[i], arr[0].ss, arr[0].ff);
			fill(root, par[root], arr[1].ss, arr[1].ff);
			ok = true;
			break;
		}else if(inside >= arr[1].ff && outside >= arr[0].ff){
			fill(i, par[i], arr[1].ss, arr[1].ff);
			fill(root, par[root], arr[0].ss, arr[0].ff);
			ok = true;
			break;
		}
		if(inside >= arr[0].ff) bigSubtreeSeen = true;
	}
	if(ok){
		res.pop_back();
		for(int &cc: res) if(cc == 0) cc = arr[2].ss;
		return res;
	}
	if(!bigSubtreeSeen){
		// Nothing was coloured, so res is all zeros ("no split").
		res.pop_back();
		return res;
	}

	// Step 2: non-root vertices, largest subtree first.
	vector<int> nodes;
	for(int j = 1; j < n; ++j) nodes.pb(j);
	sort(all(nodes), [&](const int &x, const int &y){
		return sz[x] > sz[y];
	});

	set<pair<int,int>> R; // [tin, tout] intervals of the subtrees chosen so far
	const int v = bad;
	if(sz[v] >= arr[0].ff){ // v is the "bad" (heavy) vertex
		vector<pair<int, int>> good; // (subtree size, vertex) of detachable subtrees
		int tot = 0;
		for(int u: nodes){
			// Candidate: a proper descendant of v whose own subtree is below the
			// smallest target (itself good, parent bad).
			if(!is_ancestor(v, u) || u == v || sz[u] >= arr[0].ff) continue;
			// It must have an edge that leaves v's subtree.
			const bool leaves = any_of(all(g[u]), [&](int j){ return !is_ancestor(v, j); });
			if(!leaves) continue;
			// Skip it if it lies inside a subtree that was already chosen.
			auto it = R.lower_bound(pair<int,int>{tin[u], -1});
			if(it == R.begin() || prev(it)->second < tin[u]){
				R.insert({tin[u], tout[u]});
				good.pb({sz[u], u});
				tot += sz[u];
			}
		}
		if(sz[root] - (sz[v] - tot) >= arr[0].ff){
			sort(all(good), greater<pair<int,int>>());
			// Detach the largest candidates until the outside part reaches arr[0].
			// Note: this shrinks the global sz[v] in place.
			vector<int> U;
			for(const auto &p: good){
				if(sz[root] - sz[v] >= arr[0].ff) break;
				sz[v] -= p.ff;
				U.pb(p.ss);
			}
			if(sz[root] - sz[v] >= arr[0].ff && sz[v] >= arr[1].ff){
				// Outside of v (plus detached subtrees) -> smallest class; v's rest -> middle class.
				for(int j = 0; j < n; ++j){
					if(!is_ancestor(v, j)){
						res[j] = arr[0].ss;
						vis[j] = 1;
						--arr[0].ff;
					}
				}
				for(int u: U){
					if(arr[0].ff > 0)
						fill(u, par[u], arr[0].ss, arr[0].ff);
				}
				fill(v, par[v], arr[1].ss, arr[1].ff);
			}else if(sz[root] - sz[v] >= arr[1].ff && sz[v] >= arr[0].ff){
				// Outside of v (plus detached subtrees) -> middle class; v's rest -> smallest class.
				for(int j = 0; j < n; ++j){
					if(!is_ancestor(v, j)){
						res[j] = arr[1].ss;
						vis[j] = 1;
						--arr[1].ff;
					}
				}
				for(int u: U){
					if(arr[1].ff > 0)
						fill(u, par[u], arr[1].ss, arr[1].ff);
				}
				fill(v, par[v], arr[0].ss, arr[0].ff);
			}else{
				assert(false);
			}
			ok = true;
		}
	}

	res.pop_back();
	if(ok){
		for(int &cc: res) if(cc == 0) cc = arr[2].ss;
	}
	return res;
}
// Entry point, presumably the grader interface declared in split.h.
// Returns one label in {1,2,3} per vertex, or all zeros when no valid split
// exists. The work is done by f().
// The former nn <= 3000 self-check was removed. Its component counter was never
// incremented, so it could not fail, and it re-ran dfs(), clobbering
// sz/par/tin/tout/bad.
vector<int> find_split(int nn, int a, int b, int c, vector<int> Y, vector<int> X) {
	// f takes the edge lists by value, so hand them over instead of copying.
	ress = f(nn, a, b, c, std::move(Y), std::move(X));
	return ress;
}

컴파일 시 표준 에러 (stderr) 메시지

split.cpp: In function 'std::vector<int> f(int, int, int, int, std::vector<int>, std::vector<int>)':
split.cpp:94:19: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   94 |  for(int i = 1; i < Y.size(); ++i){
      |                 ~~^~~~~~~~~~
split.cpp:99:19: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   99 |  for(int i = 0; i < Y.size(); ++i){
      |                 ~~^~~~~~~~~~
split.cpp:264:16: warning: comparison of integer expressions of different signedness: 'std::vector<int>::size_type' {aka 'long unsigned int'} and 'int' [-Wsign-compare]
  264 |    if(X.size() == nn-1) break;
      |       ~~~~~~~~~^~~~~~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...