제출 #1172240

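# | 제출 시각 (submitted at) | 아이디 (user) | 문제 (problem) | 언어 (language) | 결과 (result) | 실행 시간 (time) | 메모리 (memory)
1172240 |  | gyg | 열쇠 (IOI21_keys) | C++20 | 37 / 100 | 3100 ms | 154720 KiB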
#include "keys.h"
#include <bits/stdc++.h>
using namespace std;
// Short type aliases. Alias templates (unlike #define) only affect template-ids,
// so a local variable named e.g. `vec` or `arr` is no longer silently rewritten.
template <class T> using vec = std::vector<T>;
template <class T, std::size_t K> using arr = std::array<T, K>;
using pii = std::pair<int, int>;
// Member-name shorthands (`x.fir`, `x.sec`); these cannot be expressed as aliases.
#define fir first
#define sec second
// Exact integer constants (previously written as floating literals 3e5 + 5 and 1e9).
constexpr int N = 300'005, INF = 1'000'000'000;

// Zero indexing

// Solution to IOI 2021 "Keys" (zero-indexed rooms and key types).
//
// Rooms are contracted into groups whose members can all reach each other. Each
// group is explored once; when it first finds an openable corridor that leaves
// the group it either
//   * closes a cycle of "group A reaches group B" links - every group on that
//     cycle is merged into it (small-to-large) and exploration continues, or
//   * links itself below the other group's tree and stops for good, because it
//     reaches strictly more rooms than that tree's terminal group.
// A group that runs out of corridors without leaving reaches exactly its own
// members; the answer is every room of a smallest such group.
namespace {

// Disjoint-set forest over room indices with path halving. Callers link roots
// themselves (parent[child] = root), so they decide which root survives.
struct UnionFind {
	std::vector<int> parent;

	explicit UnionFind(int count) : parent(count) {
		std::iota(parent.begin(), parent.end(), 0);
	}

	// Returns the root of x's set.
	int find(int x) {
		while (parent[x] != x) {
			parent[x] = parent[parent[x]];
			x = parent[x];
		}
		return x;
	}
};

// Exploration state of one group of mutually reachable rooms.
struct Component {
	std::set<int> keys;                      // key types found in the group's rooms
	std::map<int, std::vector<int>> locked;  // key type -> rooms behind corridors needing it (never a held key type)
	std::vector<int> open;                   // rooms behind passable corridors, not yet examined
	long long weight = 0;                    // sum of (degree + 1) over members; bounds the container sizes
	int members = 0;                         // number of rooms in the group
};

// Moves everything held by `src` into `dst`. Corridors (of either side) whose key
// type becomes available through the union are moved to `dst.open`. Leaves `src` empty.
void absorb(Component& dst, Component& src) {
	for (const int key : src.keys) {
		if (!dst.keys.insert(key).second) continue;
		// A newly gained key type unlocks dst's corridors waiting for it.
		const auto it = dst.locked.find(key);
		if (it == dst.locked.end()) continue;
		dst.open.insert(dst.open.end(), it->second.begin(), it->second.end());
		dst.locked.erase(it);
	}
	for (auto& [key, rooms] : src.locked) {
		std::vector<int>& target = dst.keys.count(key) ? dst.open : dst.locked[key];
		target.insert(target.end(), rooms.begin(), rooms.end());
	}
	dst.open.insert(dst.open.end(), src.open.begin(), src.open.end());
	dst.weight += src.weight;
	dst.members += src.members;
	src = Component{};
}

// Lifecycle of a group root: still exploring, linked into another group's tree,
// or finished without ever leaving.
enum class Status : unsigned char { Active, Escaped, Stuck };

}  // namespace

// Returns ans with ans[i] = 1 iff the number of rooms reachable from room i is the
// minimum over all starting rooms, 0 otherwise. _cl[i] is the key type in room i;
// corridor j joins rooms _u[j] and _v[j] and needs key type _c[j].
// Roughly O((n + m) log^2 n): every stored item moves only when its group's weight
// at least doubles.
std::vector<int> find_reachable(std::vector<int> _cl, std::vector<int> _u, std::vector<int> _v, std::vector<int> _c) {
	const int n = static_cast<int>(_cl.size());
	const int m = static_cast<int>(_u.size());

	std::vector<Component> comp(n);
	for (int room = 0; room < n; ++room) {
		comp[room].keys.insert(_cl[room]);
		comp[room].weight = 1;
		comp[room].members = 1;
	}
	// Registers the one-way view `from -> to` of a corridor needing key type `key`.
	const auto add_corridor = [&](int from, int to, int key) {
		Component& at = comp[from];
		++at.weight;
		if (key == _cl[from]) at.open.push_back(to);
		else at.locked[key].push_back(to);
	};
	for (int i = 0; i < m; ++i) {
		add_corridor(_u[i], _v[i], _c[i]);
		add_corridor(_v[i], _u[i], _c[i]);
	}

	UnionFind groups(n);  // room -> group of mutually reachable rooms
	UnionFind chains(n);  // room -> tree of "escaped into" links between groups
	std::vector<int> sink(n);  // for a `chains` root: a room of that tree's terminal group
	std::iota(sink.begin(), sink.end(), 0);
	std::vector<int> exit_room(n, -1);  // for an escaped group root: a room it escaped into
	std::vector<Status> status(n, Status::Active);

	// Merges distinct group roots a and b (smaller weight into larger); returns the new root.
	const auto unite = [&](int a, int b) {
		if (comp[a].weight < comp[b].weight) std::swap(a, b);
		absorb(comp[a], comp[b]);
		groups.parent[b] = a;
		return a;
	};

	for (int start = 0; start < n; ++start) {
		int cur = groups.find(start);
		if (status[cur] != Status::Active) continue;
		// Explore `cur` until it escapes or runs dry; cycle merges keep it going.
		while (!comp[cur].open.empty()) {
			const int room = comp[cur].open.back();
			comp[cur].open.pop_back();
			const int other = groups.find(room);
			if (other == cur) continue;
			if (groups.find(sink[chains.find(room)]) != cur) {
				// `other` belongs to a different tree: hang below it and stop for good.
				status[cur] = Status::Escaped;
				exit_room[cur] = room;
				const int from_tree = chains.find(cur);
				const int to_tree = chains.find(room);
				chains.parent[from_tree] = to_tree;
				break;
			}
			// `other` already leads back to `cur`: merge every group on that path.
			for (int x = other; x != cur;) {
				const int following = exit_room[x];
				cur = unite(cur, x);
				x = groups.find(following);
			}
			status[cur] = Status::Active;
		}
		if (status[cur] == Status::Active) status[cur] = Status::Stuck;
	}

	int best = n + 1;
	for (int room = 0; room < n; ++room)
		if (groups.find(room) == room && status[room] == Status::Stuck)
			best = std::min(best, comp[room].members);

	std::vector<int> ans(n, 0);
	for (int room = 0; room < n; ++room) {
		const int root = groups.find(room);
		if (status[root] == Status::Stuck && comp[root].members == best) ans[room] = 1;
	}
	return ans;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...