Submission #1213460

#TimeUsernameProblemLanguageResultExecution timeMemory
1213460madamadam3Connecting Supertrees (IOI20_supertrees)C++20
40 / 100
121 ms22256 KiB
#include "supertrees.h"
#include <bits/stdc++.h>

using namespace std;

#define sz(x) int((x).size())
#define bg(x) (x).begin()
#define en(x) (x).end()
#define all(x) bg((x)), en((x))
#define FOR(i, a, b) for (int i = a; i < b; i++)
#define pb push_back

using vi = vector<int>;
using vvi = vector<vi>;

struct DSU {
	int n; vector<int> par, siz;

	DSU(int N) {
		n = N;
		par.resize(n); iota(all(par), 0);
		siz.assign(n, 1);
	}

	int find(int v) {
		if (par[v] == v) return v;
		return par[v] = find(par[v]);
	}

	void unite(int a, int b) {
		a = find(a); b = find(b);
		if (a != b) {
			if (siz[a] < siz[b]) swap(a, b);
			par[b] = a;
			siz[a] += siz[b];
		}
	}
};

int construct(vvi p) {
	int n = sz(p);
	auto dsu = DSU(n);

	vector<bool> all2(n, true);
	FOR(i, 0, n) FOR(j, 0, n) all2[i] = all2[i] && (i == j || (p[i][j] == 0 || p[i][j] == 2));

	for (int i = 0; i < n; i++) {
		for (int j = i+1; j < n; j++) {
			if (p[i][j] == 0) continue;
			if (p[i][j] == 3) return 0;

			if (all2[i] && all2[j]) dsu.unite(i, j);
			else if (!all2[i] && !all2[j] && p[i][j] == 1) {
				dsu.unite(i, j);
			}
		}
	}

	FOR(i, 0, n) {
		FOR(j, i+1, n) {
			if (p[i][j] == 0 && dsu.find(i) == dsu.find(j)) {
				return 0;
			}
			if (p[i][j] == 2 && dsu.siz[dsu.find(i)] <= 2) {
				return 0;
			}
		}
	}

	vvi cmps(n); FOR(i, 0, n) cmps[dsu.find(i)].pb(i);
	vi tail(n, -1), head(n, -1);

	vvi answer(n, vector<int>(n, 0));
	FOR(cmpid, 0, n) {
		if (dsu.find(cmpid) != cmpid) continue;
		if (sz(cmps[cmpid]) == 0) continue;
		if (sz(cmps[cmpid]) == 1) {
			head[cmpid] = tail[cmpid] = cmpid;
			continue;
		}

		int fst = cmps[cmpid][0], prev = cmps[cmpid][0];
		FOR(i, 1, sz(cmps[cmpid])) {
			int cur = cmps[cmpid][i];
			answer[cur][prev] = answer[prev][cur] = 1;
			prev = cur;
		}

		tail[cmpid] = fst;
		head[cmpid] = prev;
	}

	// FOR(i, 0, n) {
	// 	if (all2[i]) continue;
	// 	FOR(j, 0, n) {
	// 		if (dsu.find(i) == dsu.find(j) && p[i][j] != 1) return 0;
	// 	}
	// }

	FOR(cmpid, 0, n) {
		// if (sz(cmps[cmpid]) <= 1) continue;
		if (sz(cmps[cmpid]) == 0) continue;
		if (dsu.find(cmpid) != cmpid) continue;
		if (!all2[cmpid]) continue;

		vector<int> substituents;
		for (int i = 0; i < n; i++) {
			if (all2[i]) continue;
			if (p[cmpid][i] < 1) continue;

			substituents.push_back(dsu.find(i));
			// cout << cmpid << " has substituent " << dsu.find(i) << "\n";
		}

		sort(all(substituents));
		substituents.erase(unique(all(substituents)), en(substituents));

		int h = head[cmpid], t = tail[cmpid];

		if (sz(substituents) == 0) {
			if (sz(cmps[cmpid]) > 2) {
				answer[h][t] = answer[t][h] = 1;
			}
		} else { 
			int prev = substituents[0];
			int fst = prev;
			
			for (int i = 1; i < sz(substituents); i++) {
				answer[prev][substituents[i]] = answer[substituents[i]][prev] = 1;
				dsu.unite(prev, substituents[i]);
				prev = substituents[i];
			}

			answer[h][prev] = answer[prev][h] = 1;
			answer[t][fst] = answer[fst][t] = 1;
			dsu.unite(h, prev); dsu.unite(fst, t);
		}
	}

	cmps.assign(n, vector<int>());
	FOR(i, 0, n) cmps[dsu.find(i)].pb(i);
	vector<bool> vcmp(n, true); FOR(cmpid, 0, n) {for (auto &el : cmps[cmpid]) {vcmp[cmpid] = vcmp[cmpid] && !all2[el];}}

	FOR(cmpid, 0, n) {
		if (!vcmp[cmpid]) continue;
		if (sz(cmps[cmpid]) <= 0) continue;
		if (dsu.find(cmpid) != cmpid) continue;
		
		vector<int> substituents;
		for (int i = 0; i < n; i++) {
			if (p[cmpid][i] != 2) continue;
			if (dsu.find(i) == dsu.find(cmpid)) continue;
			if (all2[i]) continue;

			substituents.pb(dsu.find(i));
		}

		sort(all(substituents));
		substituents.erase(unique(all(substituents)), en(substituents));
		if (substituents.size() == 0) continue;
		if (sz(substituents) == 1) return 0;
		int prev = substituents[0];
		int fst = prev;
		FOR(i, 1, sz(substituents)) {
			answer[prev][substituents[i]] = answer[substituents[i]][prev] = 1;
			dsu.unite(prev, substituents[i]);
			prev = substituents[i];
		}

		answer[prev][cmpid] = answer[cmpid][prev] = 1;
		answer[fst][cmpid] = answer[cmpid][fst] = 1;
		dsu.unite(prev, cmpid); dsu.unite(fst, cmpid);
	}

	build(answer);
	return 1;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...