답안 #630494

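| # | 제출 시각 (Submitted at) | 아이디 (User) | 문제 (Problem) | 언어 (Language) | 결과 (Result) | 실행 시간 (Run time) | 메모리 (Memory) |
|---|---|---|---|---|---|---|---|
| 630494 | 2022-08-16T12:27:40 Z | Arnch | Mergers (JOI19_mergers) | C++17 | 컴파일 오류 (Compilation error) | 0 ms | 0 KB |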
// oooo
/*
 har chi delet mikhad bebar ~
 gitar o ba khodet nabar! ~
 ;Amoo_Hasan;
*/

#include<bits/stdc++.h>
#pragma GCC optimize("O3,no-stack-protector,unroll-loops")
#pragma GCC target("avx2,fma")

using namespace std;

typedef long long ll;
typedef long double ld;

// Contest-style shorthand macros.
#define Sz(x) int((x).size())
#define All(x) (x).begin(), (x).end()
#define wtf(x) cout<<#x <<" : " <<x <<endl
#define mak make_pair

//constexpr int PRI = 1000696969;
// N    : capacity of the segment-tree array seg[] (recursive tree rooted at index 1).
// MAXN : capacity of all per-vertex / per-state arrays (presumably the problem's n, k bound plus slack).
// LOG  : binary-lifting levels; 2^LOG > MAXN, so any depth can be climbed.
constexpr int N = 2e6 + 2, MAXN = 5e5 + 2, LOG = 20;

// n = number of vertices, k = number of states (both read in main).
int n, k;
// s[v]    : state of vertex v (made 0-based in main).
// sub[v]  : subtree size (dfs).
// head[v] : top vertex of v's heavy chain (hld).
// link[t] : vertex placed at HLD position t, the inverse of st[].
//           NOTE(review): the name may clash with POSIX link() if <unistd.h> ends up visible.
int s[MAXN], sub[MAXN], head[MAXN], link[MAXN];
// par[v][i] : 2^i-th ancestor of v (par[root][0] == -1, set by dfs).
// h[v]      : depth, root = 0.
// st[v], fn[v] : HLD position of v and one past the last position of its subtree.
// tim       : next free HLD position.
int par[MAXN][LOG], h[MAXN], st[MAXN], fn[MAXN], tim;
// Disjoint-set union over states: sz = set size, pv = parent pointer.
int sz[MAXN], pv[MAXN];
// ans   : number of merged components with exactly one neighbouring component.
// total : largest segment-tree index touched by build() (not read elsewhere in the visible code).
int ans, total;
// seg[v] : state id tagged on segment node v, or -1 if none.
int seg[N];
// vc[x]  : vertices with state x; adj = tree adjacency; nei = adjacency between merged components.
vector<int> vc[MAXN], adj[MAXN], nei[MAXN];

// Root the tree at x (main calls dfs(0)) and fill, for every vertex v in the subtree:
//   par[v][i] - the 2^i-th ancestor of v, or -1 when that jump would go above the root;
//   h[v]      - depth (the root keeps h == 0 from zero-initialised globals);
//   sub[v]    - subtree size (used by hld to pick heavy children).
// p is the parent of x, or -1 for the root.
// NOTE: recursive, depth = tree height, so a path-shaped tree needs a deep stack.
void dfs(int x, int p = -1) {
	par[x][0] = p;
	for(int i = 1; i < LOG; i++) {
		// Once a jump has left the tree (-1), every longer jump stays at -1.
		// Indexing par[-1][...] here would be an out-of-bounds read.
		const int mid = par[x][i - 1];
		par[x][i] = (mid == -1) ? -1 : par[mid][i - 1];
	}
	if(p != -1)
		h[x] = h[p] + 1;
	sub[x] = 1;
	for(auto j : adj[x]) {
		if(j == p) continue;
		dfs(j, x);
		sub[x] += sub[j];
	}
}
// Heavy-light decomposition over the subtree of x (requires sub[] from dfs).
// It assigns positions so that every heavy chain is contiguous and records:
//   st[x]     - position of x; link[pos] maps a position back to its vertex;
//   head[x]   - top vertex of the chain containing x (hi);
//   fn[x]     - one past the last position used by x's subtree,
//               so that subtree occupies [st[x], fn[x]).
// p is x's parent (-1 for the root).
void hld(int x, int p = -1, int hi = 0) {
	const int pos = tim++;
	st[x] = pos;
	link[pos] = x;
	head[x] = hi;

	// The heavy child is the first child, in adjacency order, with the largest subtree.
	int heavy = -1;
	for(int c : adj[x])
		if(c != p && (heavy == -1 || sub[c] > sub[heavy]))
			heavy = c;

	if(heavy != -1) {
		hld(heavy, x, hi);      // the heavy child extends the current chain
		for(int c : adj[x])
			if(c != p && c != heavy)
				hld(c, x, c);   // each light child starts a new chain
	}
	fn[x] = tim;
}

// Return the ancestor of x that is y levels higher, using the binary-lifting table.
// For each set bit i of y (bits 0..LOG-1 only), jump 2^i levels.
// The caller keeps y <= h[x].
int get_par(int x, int y) {
	int bit = 0;
	while(bit < LOG) {
		if((y >> bit) & 1)
			x = par[x][bit];
		++bit;
	}
	return x;
}
int lca(int x, int y) {
	if(h[x] > h[y]) swap(x, y);
	y = get_par(y, h[y] - h[x]);
	if(x == y) return x;
	for(int i = LOG - 1; i >= 0; i--)
		if(par[x][i] != par[y][i])
			x = par[x][i], y = par[y][i];
	return par[x][0];
}

// DSU lookup: return the representative of x's set.
// Every vertex on the walked path is re-pointed directly at the representative
// (full path compression).
int find(int x) {
	int root = x;
	while(pv[root] != root)
		root = pv[root];
	while(x != root) {
		const int next = pv[x];
		pv[x] = root;
		x = next;
	}
	return root;
}
// DSU union by size: join the sets holding x and y.
// The smaller set is attached under the larger; nothing happens if they are already one set.
void merge(int x, int y) {
	int a = find(x);
	int b = find(y);
	if(a != b) {
		if(sz[b] > sz[a]) swap(a, b);   // keep a as the larger set
		pv[b] = a;
		sz[a] += sz[b];
	}
}

// Reset the segment tree over positions [l, r) rooted at node v: every node gets tag -1.
// It also records in total the largest internal node index visited.
void build(int l = 0, int r = n, int v = 1) {
	seg[v] = -1;
	if(r - l >= 2) {
		const int mid = (l + r) / 2;
		build(l, mid, 2 * v);
		build(mid, r, 2 * v + 1);
		if(v > total) total = v;
	}
}
// Tag positions [s, e) with state val.
// On each canonical node fully inside the range, the node takes val if untagged.
// Otherwise val is DSU-merged with the node's existing tag.
void upd(int s, int e, int val, int l = 0, int r = n, int v = 1) {
	const bool disjoint = (r <= s) || (l >= e);
	if(disjoint) return;

	const bool covered = (s <= l) && (e >= r);
	if(!covered) {
		const int mid = (l + r) >> 1;
		upd(s, e, val, l, mid, 2 * v);
		upd(s, e, val, mid, r, 2 * v + 1);
		return;
	}

	if(seg[v] != -1) merge(seg[v], val);
	else seg[v] = val;
}
// Push every tag down to the leaves, merging along the way.
// At an internal node, its tag is given to an untagged child or merged with a
// tagged child (left child first, then right).
// At a leaf, the tag is merged with the state of the vertex stored at that position.
void relax(int l = 0, int r = n, int v = 1) {
	const int tag = seg[v];
	if(r - l < 2) {
		if(tag != -1) merge(tag, s[link[l]]);
		return;
	}
	if(tag != -1) {
		for(int child = 2 * v; child <= 2 * v + 1; ++child) {
			if(seg[child] == -1) seg[child] = tag;
			else merge(tag, seg[child]);
		}
	}
	const int mid = (l + r) >> 1;
	relax(l, mid, 2 * v);
	relax(mid, r, 2 * v + 1);
}

// Sort key for vertices: ascending HLD position st[].
bool cmp(int i, int j) {
	const int a = st[i];
	const int b = st[j];
	return a < b;
}
void solve(int x) {
	vector<int> ver;
	for(auto i : vc[x]) ver.push_back(i);
	sort(All(ver), cmp);
	int sz = Sz(ver);
	for(int i = 1; i < sz; i++) {
		ver.push_back(lca(ver[i - 1], ver[i]));
	}
	sort(All(ver), cmp);
	ver.erase(unique(All(ver)), ver.end());

	stack<int> mt;
	mt.push(ver[0]);

	upd(st[ver[0]], st[ver[0]] + 1, x);
	for(int i = 1; i < Sz(ver); i++) {
		int v = ver[i];
		while(fn[mt.top()] < fn[v]) mt.pop();
		int p = mt.top();
	
		upd(st[v], st[v] + 1, x);
	
		int u = v;
		while(u != -1 && h[u] >= h[p]) {
			if(h[head[u]] < h[p]) break;
			upd(st[head[u]], st[u] + 1, x);
			u = par[head[u]][0];
		}
		if(h[u] >= h[p]) {
			upd(st[p], st[u] + 1, x);
		}

		mt.push(v);
	}
	ver.clear();
	mt.clear();
}

// Reads the tree (n vertices, n - 1 edges) and the state of every vertex
// (presumably 1..k; made 0-based here).
// Each state is DSU-merged with every state whose connecting paths share a vertex.
// The program then counts merged components that border exactly one other
// component and prints ceil(count / 2). This is 0 when everything collapses
// into one component.
int main() {
	ios :: sync_with_stdio(0), cin.tie(0); cout.tie(0);

	for(int i = 0; i < MAXN; i++) sz[i] = 1, pv[i] = i;

	cin >>n >>k;
	for(int i = 0; i < n - 1; i++) {
		int u, v; cin >>u >>v;
		--u, --v;
		adj[u].push_back(v), adj[v].push_back(u);
	}
	dfs(0);
	hld(0);

	for(int i = 0; i < n; i++) {
		cin >>s[i];
	}

	build();

	for(int i = 0; i < n; i++) {
		--s[i];
		vc[s[i]].push_back(i);
	}

	for(int i = 0; i < k; i++) {
		solve(i);
	}

	relax();

	// Flatten the DSU so pv[state] is directly the component representative.
	for(int i = 0; i < MAXN; i++) pv[i] = find(pv[i]);

	// Each tree edge whose endpoints are in different components becomes a component-level edge.
	for(int i = 0; i < n; i++) {
		int u = s[i];
		for(auto j : adj[i]) {
			int v = s[j];
			if(pv[u] == pv[v]) continue;
			nei[pv[u]].push_back(pv[v]);
		}
	}

	for(int i = 0; i < MAXN; i++) {
		// unique() only removes adjacent duplicates; sort first so we count distinct neighbours.
		sort(All(nei[i]));
		nei[i].erase(unique(All(nei[i])), nei[i].end());
		if(Sz(nei[i]) == 1) ans++;
	}
	cout<<(ans + 1) / 2 <<'\n';

	return 0;
}

Compilation message

mergers.cpp: In function 'void solve(int)':
mergers.cpp:169:5: error: 'class std::stack<int>' has no member named 'clear'
  169 |  mt.clear();
      |     ^~~~~