Submission #102380

# | Time | Username | Problem | Language | Result | Execution time | Memory
102380 | — | Dat160601 | Mergers (JOI19_mergers) | C++17 | 70 / 100 | 3041 ms | 208220 KiB
#include <bits/stdc++.h>
using namespace std;
#define mp make_pair
#define pb push_back
#define fi first
#define se second

const int N = 5e5 + 7;  // capacity for 1-indexed vertex and state ids

// Sizes read from input, plus scratch variables used while reading edges.
int n, k, u, v;
// Per-vertex data: state of the vertex, depth (root has level 1), number of
// states whose cities' LCA is this vertex, and the binary-lifting table
// where par[i][x] is the 2^i-th ancestor of x.
int s[N], level[N], sub[N], par[20][N];
// Per-state DSU data: parent pointer, compressed-component id, and the
// count of member states whose LCA has not been processed yet.
int pset[N], ret[N], sz[N];
// Number of component ids handed out, and the leaf counter used by dfs().
int cnt = 0, ans = 0;
// vis marks vertices already pushed to the processing queue; col is not
// referenced in the visible code.
bool vis[N], col[N];
// Tree adjacency, cities of each state, compressed-graph adjacency, and the
// childless vertices collected by predfs().
vector <int> edge[N], tp[N], g[N], leaf;
// Map keyed by ordered pairs of component ids (see main).
map < pair <int, int>, int > ed;

// Roots the tree at u (whose parent is p). For every other vertex v reached
// it sets par[0][v] to v's parent and level[v] = level[parent] + 1, and it
// appends every vertex with no children to `leaf`. Leaves are appended in the
// same pre-order as a recursive DFS would produce.
// An explicit stack is used instead of recursion: the tree may be a path of
// up to ~N vertices, which would overflow the call stack.
void predfs(int u, int p){
	// One frame per active vertex: the vertex, its parent, the next adjacency
	// index to scan, and how many children have been found so far.
	struct Frame { int node; int parent; size_t next; int children; };
	vector<Frame> stk;
	stk.push_back({u, p, 0, 0});
	while(!stk.empty()){
		Frame &f = stk.back();
		if(f.next < edge[f.node].size()){
			int w = edge[f.node][f.next++];
			if(w == f.parent) continue;
			par[0][w] = f.node;
			level[w] = level[f.node] + 1;
			f.children++;
			// push_back may invalidate `f`; it is not touched again this iteration.
			stk.push_back({w, f.node, 0, 0});
		} else {
			if(f.children == 0) leaf.push_back(f.node);
			stk.pop_back();
		}
	}
}

int lca(int u, int v){
	if(level[u] > level[v]) swap(u, v);
	for(int i = 19; i >= 0; i--) if(level[par[i][v]] >= level[u]) v = par[i][v];
	for(int i = 19; i >= 0; i--) if(par[i][u] != par[i][v]) u = par[i][u], v = par[i][v];
	if(u == v) return u;
	return par[0][u];
}

// Returns the DSU representative of state x. Every state on the path from x
// to the root is re-pointed directly at the root (path compression).
// Iterative on purpose: unionset() can build long parent chains before any
// compression happens, and a recursive find would recurse once per link.
int fset(int x){
	int root = x;
	while(pset[root] != root) root = pset[root];
	// Second pass: point every state on the path directly at the root.
	while(pset[x] != root){
		int nxt = pset[x];
		pset[x] = root;
		x = nxt;
	}
	return root;
}

// Merges the DSU sets containing states u and v. The representative of u's
// set is attached under the representative of v's set. Its sz counter is
// moved onto the surviving root and zeroed on the absorbed one.
void unionset(int u, int v){
	const int from = fset(u);
	const int to = fset(v);
	if(from != to){
		pset[from] = to;
		sz[to] += sz[from];
		sz[from] = 0;
	}
}

// Walks the compressed graph g from u (arriving from p) and adds to `ans`
// one for every vertex with no children. Vertex 1 also counts when it has
// exactly one child, i.e. when the root itself is a leaf.
// An explicit stack replaces recursion because the compressed tree can be a
// path with up to ~N vertices.
void dfs(int u, int p){
	vector<pair<int, int>> stk;
	stk.emplace_back(u, p);
	while(!stk.empty()){
		auto [node, parent] = stk.back();
		stk.pop_back();
		int child = 0;
		for(int w : g[node]){
			if(w == parent) continue;
			child++;
			stk.emplace_back(w, node);
		}
		if(child == 0 || (child == 1 && node == 1)) ans++;
	}
}

// Reads the tree (n vertices, n-1 edges) and the state s[i] of every vertex.
// States that can never be separated by cutting a single tree edge are merged
// into components. The program prints 0 when only one component exists;
// otherwise it prints ceil(L / 2), where L is the number of leaves of the
// tree formed by the components.
int main(){
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);
	cin >> n >> k;
	for(int i = 1; i < n; i++){
		cin >> u >> v;
		edge[u].push_back(v);
		edge[v].push_back(u);
	}
	for(int i = 1; i <= n; i++){
		cin >> s[i];
		tp[s[i]].push_back(i);
	}

	// Root the tree at vertex 1 (its own parent) and build the lifting table.
	par[0][1] = 1;
	level[1] = 1;
	predfs(1, 1);
	for(int i = 1; i <= 19; i++){
		for(int j = 1; j <= n; j++){
			par[i][j] = par[i - 1][par[i - 1][j]];
		}
	}

	// Every state starts as its own DSU set with one pending member. sub[x]
	// counts the states whose cities' LCA is x; states with no city add nothing.
	for(int i = 1; i <= k; i++){
		pset[i] = i;
		sz[i] = 1;
		if(tp[i].empty()) continue;
		int cur = tp[i][0];
		for(size_t j = 1; j < tp[i].size(); j++){
			cur = lca(cur, tp[i][j]);
		}
		sub[cur]++;
	}

	// Process vertices deepest-first. After retiring the states whose LCA is x,
	// the edge (x, parent) can be cut iff x's set has no pending state left.
	// Otherwise x's set is merged with the parent's state.
	priority_queue<pair<int, int>> q;
	for(int x : leaf){
		vis[x] = true;
		q.emplace(level[x], x);
	}
	while(!q.empty()){
		int x = q.top().second;
		q.pop();
		int root = fset(s[x]);
		sz[root] -= sub[x];
		int up = par[0][x];
		if(!vis[up]){
			vis[up] = true;
			q.emplace(level[up], up);
		}
		if(sz[root] == 0) continue;
		unionset(root, s[up]);
	}

	// Assign consecutive component ids to DSU roots, then to every state.
	for(int i = 1; i <= k; i++){
		if(pset[i] == i) ret[i] = ++cnt;
	}
	for(int i = 1; i <= k; i++){
		if(pset[i] != i) ret[i] = ret[fset(i)];
	}

	// Components are connected subtrees of a tree, so two components share at
	// most one edge. Each parent edge joining different components is therefore
	// a distinct edge of the compressed tree, and degrees can be counted
	// without de-duplication. Only vertices' states are looked at, so states
	// with no city never appear as bogus isolated components.
	vector<int> deg(cnt + 1, 0);
	int crossEdges = 0;
	for(int x = 2; x <= n; x++){
		int a = ret[s[x]];
		int b = ret[s[par[0][x]]];
		if(a == b) continue;
		deg[a]++;
		deg[b]++;
		crossEdges++;
	}
	if(crossEdges == 0){
		cout << 0;
		return 0;
	}
	int leaves = 0;
	for(int c = 1; c <= cnt; c++){
		if(deg[c] == 1) leaves++;
	}
	cout << (leaves + 1) / 2;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...