Submission #763577

#TimeUsernameProblemLanguageResultExecution timeMemory
763577MISM06Mergers (JOI19_mergers)C++14
48 / 100
208 ms262144 KiB
//0 1 1 0 1
//0 1 0 0 1
//1 0 0 1 1
//0 1 1 0 1
#include <bits/stdc++.h>
#pragma GCC optimize("Ofast,unroll-loops")
#pragma GCC target("avx2")

using namespace std;

#define F 			first
#define S 			second
#define pb 			push_back
#define sze			size()
#define	all(x)		x.begin() , x.end()
#define wall__		cout << "--------------------------------------\n";
#define kids		int mid = (tl + tr) >> 1, cl = v << 1, cr = v << 1 | 1
#define file_io		freopen("input.cpp", "r", stdin); freopen("output.cpp", "w", stdout);

typedef long long ll;
typedef long double dl;
typedef pair < int , int > pii;
typedef pair < int , ll > pil;
typedef pair < ll , int > pli;
typedef pair < ll , ll > pll;
typedef pair < int , pii > piii;
typedef pair < ll, pll > plll;


const ll N = 1e5 + 10;
const ll mod = 1e9 + 7;
const ll inf = 2e16;
const ll INF = 1e9 + 10;
const ll lg = 32;

int n, k, a[N], timer, c[N], sub[N], cnt, lst = 0;
vector < int > g[N];
vector < pii > edg[N];
vector < int > mark[N];
int we[N];

void dfs (int v, int p) {
	sub[v] = 1; mark[v][a[v]] = 1; 
	for (auto u : g[v]) {
		if (u == p) continue;
		dfs(u, v);
		sub[v] += sub[u];
		for (int i = 1; i <= k; i++) mark[v][i] |= mark[u][i];
	}
	ll x = 0;
	for (int y = 1; y <= k; y++) {
		if (mark[v][y]) x += c[y];
	}
	int w = 0;
	if (x == sub[v]) {
		w = 1;
	}
	we[lst] = w;
	if (v != 1) {
		edg[v].pb({p, lst});
		edg[p].pb({v, lst});
	}
	++lst;
	cnt += w;
}

int h[N];
pii par[N];
void dfs0(int v, int p) {
	for (auto e : edg[v]) {
		int u = e.F;
		if (u == p) continue;
		h[u] = h[v] + we[e.S];
		dfs0(u, v);
	}
}
void dfs1(int v, int p, int id) {
	par[v] = {p, id};
	h[v] = h[p] + we[id];
	for (auto e : edg[v]) {
		int u = e.F;
		if (u == p) continue;
		dfs1(u, v, e.S);
	}
}

void solve () {

	cin >> n >> k;
	for (int i = 1; i < n; i++) {
		int v, u; cin >> v >> u;
		g[v].pb(u); g[u].pb(v);
	}
	for (int i = 1; i <= n; i++) {
		mark[i].resize(k + 1, 0);
	}
	for (int i = 1; i <= n; i++) {
		cin >> a[i]; ++c[a[i]];
	}
	dfs(1, 0);
	int ans = 0;
	--cnt;
	while(cnt) {
		fill(h, h + (n + 1), 0);
		dfs0(1, 0);
		int mx = 0;
		for (int i = 1; i <= n; i++) if (h[mx] < h[i]) mx = i;
		++ans;
		fill(h, h + (n + 1), 0);
		fill(par, par + (n + 1), make_pair(0, 0));
		dfs1(mx, 0, 0);
		int mx2 = 0;
		for (int i = 1; i <= n; i++) if (h[mx2] < h[i]) mx2 = i;
		while(mx2 != mx) {
			int id = par[mx2].S;
			cnt -= we[id];
			we[id] = 0;
			mx2 = par[mx2].F;
		}
	}
	cout << ans << '\n';


}


int main() {
	// ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
	int t = 1;
	// cin >> t;
	while (t--) {solve();}
    return 0;
}
/*
*/
//shrek will AC this;
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...