Submission #1081059

# | Time | Username | Problem | Language | Result | Execution time | Memory
1081059 | (not shown) | AmirAli_H1 | Beech Tree (IOI23_beechtree) | C++17 | 71 / 100 | 147 ms | 56660 KiB
// In the name of Allah

#include <bits/stdc++.h>
#include "beechtree.h"
using namespace std;

// Short type aliases used throughout the file.
typedef			long long				ll;
typedef			pair<int, int>			pii;
typedef			pair<ll, ll>			pll;

// Abbreviation macros. Note that `endl` is redefined to a plain '\n' (no flush),
// and len(x) yields the container size as a signed ll.
#define			F						first
#define			S						second
#define			endl					'\n'
#define			sep						' '
#define			pb						push_back
#define			Mp						make_pair
#define			all(x)					(x).begin(),(x).end()
#define			len(x)					((ll) (x).size())

// Capacities: maxn bounds the vertex count; maxs bounds n for the O(n^2) val
// table, which only solve1 (taken when n <= 2000) uses. 2^maxlg = 4096 >= maxs,
// so maxlg binary-lifting levels are enough for lca(), which only solve1 calls.
const int maxn = 2e5 + 4;
const int maxs = 2000 + 4;
const int maxlg = 12;

// n = number of vertices, m = number of colours (ls1 / ls2 are sized m).
int n, m;
// par[i]: parent of i (pre_dfs treats -1 as "no parent");
// col[i]: 0-based colour of the edge from par[i] to i (for i >= 1).
int par[maxn], col[maxn];
// adj[v]: (colour, child) pairs of v, sorted in beechtree().
vector<pii> adj[maxn];
// h[v]: depth below the root (set in pre_dfs). arr: scratch vertex list.
int h[maxn]; vector<int> arr;
// Per-colour scratch tables, filled by setx and reset by setd:
// -1 = no child of that colour, -2 = colour repeated, otherwise the unique child.
vector<int> ls1, ls2;
// Pairwise results of solve1: 1 / -1 / 0 ranking, or -23 for an incompatible pair.
int val[maxs][maxs];
// M[v]: violations recorded at v. dfs() turns it into a subtree sum, and v is
// reported as 1 in ans exactly when that sum is 0.
int M[maxn]; vector<int> ans;
// Euler-tour interval [st[v], ft[v]) of each vertex (used by is_gr); timer is the counter.
int st[maxn], ft[maxn], timer;
// up[v][k]: 2^k-th ancestor (saturating at the root). cnt: not referenced in the visible code.
int up[maxn][maxlg], cnt[maxn];
// mxh[v]: height of v's subtree in edges. cntx: value -> occurrences in col[];
// mxcnt: the largest such count.
int mxh[maxn], mxcnt; map<int, int> cntx;

// Sort order for solve1: strictly deeper vertices come first.
bool cmp1(int i, int j) {
	return h[j] < h[i];
}

bool cmp2(int i, int j) {
	return (Mp(len(adj[i]), -h[i]) > Mp(len(adj[j]), -h[j]));
}

// True iff v is u itself or an ancestor of u, i.e. u's Euler-tour interval
// [st[u], ft[u]) is nested inside v's.
bool is_gr(int v, int u) {
	const bool starts_inside = st[v] <= st[u];
	const bool ends_inside = ft[u] <= ft[v];
	return starts_inside && ends_inside;
}

// Lowest common ancestor of u and v via the binary-lifting table `up`.
int lca(int u, int v) {
	// One endpoint may already be an ancestor of the other.
	if (is_gr(u, v)) return u;
	if (is_gr(v, u)) return v;
	// Climb from v as far as possible while staying strictly below the LCA;
	// the parent of that vertex is the answer.
	int cur = v;
	for (int k = maxlg; k-- > 0; ) {
		const int cand = up[cur][k];
		if (!is_gr(cand, u)) cur = cand;
	}
	return up[cur][0];
}

// Recursive preprocessing of v's subtree:
//  - fills up[v][*] (the root, whose par is -1, points at itself);
//  - assigns the Euler-tour interval [st[v], ft[v]);
//  - sets h for each child (parent depth + 1);
//  - computes mxh[v], the height of v's subtree in edges.
void pre_dfs(int v) {
	up[v][0] = (par[v] == -1) ? v : par[v];
	for (int k = 1; k < maxlg; k++) {
		const int half = up[v][k - 1];
		up[v][k] = up[half][k - 1];
	}

	mxh[v] = 0;
	st[v] = timer;
	++timer;
	for (const auto & e : adj[v]) {
		const int child = e.S;
		h[child] = h[v] + 1;
		pre_dfs(child);
		if (mxh[child] + 1 > mxh[v]) mxh[v] = mxh[child] + 1;
	}
	ft[v] = timer;
}

// Turns M into subtree sums: afterwards M[x] is the total originally recorded
// at x and all of its descendants, for every x in v's subtree.
// Vertices are first listed so that every parent precedes its children, then
// folded in reverse so each child's total is final before it is added upward
// (no recursion, so deep paths do not grow the call stack).
void dfs(int v) {
	std::vector<int> order;
	std::vector<int> pending(1, v);
	while (!pending.empty()) {
		const int x = pending.back();
		pending.pop_back();
		order.push_back(x);
		for (const auto & e : adj[x]) pending.push_back(e.S);
	}
	for (auto it = order.rbegin(); it != order.rend(); ++it) {
		for (const auto & e : adj[*it]) M[*it] += M[e.S];
	}
}

// Records v's children by colour in ls. The first child of a colour is stored;
// any further child of that colour marks the colour as repeated (-2).
// Expects ls[c] == -1 for every colour not yet seen (setd restores that state).
void setx(int v, vector<int> & ls) {
	for (const auto & e : adj[v]) {
		int & slot = ls[e.F];
		slot = (slot == -1) ? e.S : -2;
	}
}

// Undoes setx(v, ls): resets the entry of every child colour of v back to -1
// ("absent") so the table is clean for the next query.
// (The child id was previously bound to an unused local, triggering
// -Wunused-variable; only the colour is needed here.)
void setd(int v, vector<int> & ls) {
	for (const auto & e : adj[v]) {
		ls[e.F] = -1;
	}
}

// Compatibility test used by solve1 for a pair (u, v). Requires the tables to be
// current: setx(u, ls1) and setx(v, ls2).
// Returns false if either vertex has a repeated child colour; otherwise returns
// true iff one vertex's set of child colours contains the other's.
bool ok(int u, int v) {
	for (const auto & e : adj[u]) if (ls1[e.F] == -2) return false;
	for (const auto & e : adj[v]) if (ls2[e.F] == -2) return false;

	// Is every child colour of u also a child colour of v?
	bool u_within_v = true;
	for (const auto & e : adj[u]) {
		if (ls2[e.F] == -1) {
			u_within_v = false;
			break;
		}
	}
	if (u_within_v) return true;

	// Otherwise every child colour of v must appear among u's.
	for (const auto & e : adj[v]) {
		if (ls1[e.F] == -1) return false;
	}
	return true;
}

// O(n^2) pairwise check, used when n <= 2000 (see beechtree()).
// Every unordered pair {u, v} is classified and stored symmetrically in val
// (val[v][u] is the negation of val[u][v], or both are -23):
//   -23    : incompatible - ok(u, v) failed, or the evidence below conflicts;
//   1 / -1 : u ranks above / below v - seeded from the child counts, then
//            refined by every same-coloured child pair (ux, vx);
//   0      : no ranking derived.
// Vertices are visited in order of non-increasing depth (cmp1). A child is one
// level deeper than its parent, so val[ux][vx] for children of u and v has
// already been filled in when (u, v) is processed.
// Each incompatible pair adds one violation at lca(u, v); dfs() later pushes
// these counts up to every ancestor.
void solve1() {
	arr.clear(); arr.resize(n);
	iota(all(arr), 0); sort(all(arr), cmp1);
	
	for (int i1 = 0; i1 < len(arr); i1++) {
		for (int i2 = i1 + 1; i2 < len(arr); i2++) {
			int u = arr[i1], v = arr[i2];
			// Per-colour child tables for both vertices (reset again at the end).
			setx(u, ls1); setx(v, ls2);
			if (!ok(u, v)) {
				val[u][v] = val[v][u] = -23;
			}
			else {
				// Seed: the vertex with more children ranks higher.
				if (len(adj[u]) > len(adj[v])) {
					val[u][v] = 1; val[v][u] = -1;
				}
				else if (len(adj[v]) > len(adj[u])) {
					val[u][v] = -1; val[v][u] = 1;
				}
				else {
					val[u][v] = val[v][u] = 0;
				}
				// Refine with each colour both vertices have a child of; a child
				// pair pointing the opposite way (or already incompatible) makes
				// (u, v) incompatible.
				for (auto f : adj[u]) {
					int c = f.F;
					int ux = ls1[c], vx = ls2[c];
					if (vx == -1) continue;
					if (val[ux][vx] == -23 || val[u][v] == -23) {
						val[u][v] = val[v][u] = -23;
					}
					else if (val[ux][vx] == 1) {
						if (val[u][v] == -1) val[u][v] = val[v][u] = -23;
						else {
							val[u][v] = 1; val[v][u] = -1;
						}
					}
					else if (val[ux][vx] == -1) {
						if (val[u][v] == 1) val[u][v] = val[v][u] = -23;
						else {
							val[u][v] = -1; val[v][u] = 1;
						}
					}
				}
			}
			// Any subtree containing both u and v contains their LCA's subtree.
			if (val[u][v] == -23) {
				int r = lca(u, v);
				M[r]++;
			}
			setd(u, ls1); setd(v, ls2);
		}
	}
}

// Appends v and its whole subtree to arr in preorder (children in adj order).
void dfsx(int v) {
	arr.pb(v);
	for (const auto & e : adj[v]) dfsx(e.S);
}

// Checks the current order of `arr`: every vertex's child colours must all be
// present among the child colours of the vertex just before it.
// Uses ls1 as scratch. ls1 is restored with setd before every exit, including
// a failed check. Previously the early return skipped setd and left stale
// entries behind. setx treats any entry other than -1 as an already-seen colour,
// so those stale entries corrupted every later okx call made by solvex.
bool okx() {
	for (int i = 1; i < len(arr); i++) {
		int v1 = arr[i - 1], v2 = arr[i];
		setx(v1, ls1);
		bool covered = true;
		for (const auto & e : adj[v2]) {
			if (ls1[e.F] == -1) {
				covered = false;
				break;
			}
		}
		setd(v1, ls1);  // always reset before leaving this iteration
		if (!covered) return false;
	}
	return true;
}

// Fallback used when n > 2000 (see beechtree()): records violations in M using
// cheaper local checks instead of the pairwise table.
void solvex() {
	// Pass 1, for every v (ls1 as scratch):
	//  - each child edge of v whose colour is repeated among v's children counts
	//    as a violation at v;
	//  - each grandchild edge whose colour is not among v's child colours counts
	//    as a violation at v.
	for (int v = 0; v < n; v++) {
		setx(v, ls1);
		for (auto f : adj[v]) {
			int c = f.F;
			if (ls1[c] == -2) M[v]++;
		}
		for (auto f : adj[v]) {
			int u = f.S;
			for (auto g : adj[u]) {
				int c = g.F;
				if (ls1[c] == -1) M[v]++;
			}
		}
		setd(v, ls1);
	}
	// Pass 2:
	//  - a subtree of height <= 2 is collected with dfsx, ordered by cmp2 (more
	//    children first, then shallower first) and checked with okx;
	//  - for a taller subtree a violation is recorded whenever no value in col[]
	//    occurs more than twice (mxcnt <= 2).
	// NOTE(review): the mxcnt branch looks like a subtask-specific shortcut rather
	// than a general criterion - confirm against the problem's constraints.
	for (int v = 0; v < n; v++) {
		if (mxh[v] <= 2) {
			arr.clear();
			dfsx(v);
			sort(all(arr), cmp2);
			if (!okx()) M[v]++;
		}
		else if (mxcnt <= 2) M[v]++;
	}
}

// Grader entry point. Nx vertices and Mx colours. Px[i] is the parent of i, and
// Cx[i] is the colour of the edge from Px[i] to i (shifted to 0-based here).
// Returns, for every vertex, 1 if no violation was recorded anywhere in its
// subtree and 0 otherwise.
vector<int> beechtree(int Nx, int Mx, vector<int> Px, vector<int> Cx) {
	n = Nx;
	m = Mx;
	for (int i = 0; i < n; i++) {
		Cx[i]--;  // colours become 0-based indices into ls1 / ls2
		par[i] = Px[i];
		col[i] = Cx[i];
		mxcnt = max(mxcnt, ++cntx[col[i]]);
	}
	for (int child = 1; child < n; child++) {
		adj[par[child]].pb(Mp(col[child], child));
	}
	for (int v = 0; v < n; v++) sort(all(adj[v]));

	pre_dfs(0);

	// Per-colour scratch tables start out as "absent" (-1) for every colour.
	ls1.assign(m, -1);
	ls2.assign(m, -1);

	if (n > 2000) solvex();
	else solve1();

	// Aggregate violations per subtree, then report violation-free subtrees.
	dfs(0);
	ans.resize(n);
	for (int v = 0; v < n; v++) ans[v] = (M[v] == 0) ? 1 : 0;
	return ans;
}

Compilation message (stderr)

beechtree.cpp: In function 'void setd(int, std::vector<int>&)':
beechtree.cpp:88:16: warning: unused variable 'u' [-Wunused-variable]
   88 |   int c = f.F, u = f.S;
      |                ^
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...