제출 #924063

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
924063 | KiaRez | 수도 (JOI20_capital_city) | C++17
100 / 100
636 ms | 43044 KiB
/*
    IN THE NAME OF GOD
*/
#include <bits/stdc++.h>

// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
// #pragma GCC optimize("O3")
// #pragma GCC optimize("unroll-loops")

using namespace std;

// Competitive-programming shorthand types.
typedef long long ll;
typedef pair<ll, ll> pll;
typedef pair<int, int> pii;
typedef long double ld;

// Shorthand member / function names.
#define F                                      first
#define S                                      second
#define Mp                                     make_pair
#define pb                                     push_back
#define pf                                     push_front
// Container size as ll. The argument is parenthesized so that expressions such
// as size(*p) expand to ((ll)(*p).size()) instead of the ill-formed (ll)*p.size().
#define size(x)                                ((ll)(x).size())
#define all(x)                                 (x).begin(),(x).end()
// Prints x and a newline, then exits with status 0. Expands to a full statement
// (it carries its own ';'). x is intentionally not parenthesized so chained
// arguments like kill(a << ' ' << b) keep working.
#define kill(x)                                cout << x << '\n', exit(0);
// Debug print in the form "(expr , value)"; like kill, x is left unparenthesized
// on purpose to allow chained stream arguments.
#define fuck(x)                                cout << "(" << #x << " , " << x << ")" << endl
// Plain newline instead of std::endl, so output is not flushed on every line.
#define endl                                   '\n'

// N: capacity of the per-vertex / per-color arrays (presumably max n plus slack).
// lg: not used in the visible code (presumably a log2 bound for lifting tables).
const int N = 2e5+23, lg = 18;
// Default modulus for MOD/poww; the alternative prime is kept in the comment.
ll Mod = 1e9+7; //998244353;

// Reduces a into the range [0, mod) for a positive modulus.
// C++'s % truncates toward zero, so a negative a gives a negative remainder;
// that remainder is shifted up by one period.
inline ll MOD(ll a, ll mod=Mod) {
    ll r = a % mod;
    if (r < 0) {
        r += mod;
    }
    return r;
}
/*
 * Computes a^b modulo `mod` using binary exponentiation (repeated squaring).
 * The result is normalized into [0, mod); a may be negative (it is reduced with
 * MOD first).
 *
 * Preconditions:
 *   - mod >= 1.
 *   - b >= 0. Negative exponents are not supported; they return 1 % mod.
 *   - The products ans*a and a*a must fit in ll, so mod should not exceed
 *     about 3.03e9.
 */
inline ll poww(ll a, ll b, ll mod=Mod) {
    // Start from 1 reduced modulo `mod`, so that mod == 1 yields 0 rather than 1.
    ll ans = MOD(1, mod);
    a = MOD(a, mod);
    // `b > 0` rather than `b` guarantees termination: right-shifting a negative
    // exponent keeps it negative (on GCC/Clang), so the loop would spin forever.
    while (b > 0) {
        if (b & 1) ans = MOD(ans*a, mod);
        b >>= 1;
        a = MOD(a*a, mod);
    }
    return ans;
}

// n: vertex count. k: read from input (presumably the number of colors).
// ans: best answer so far (set to k-1 in main, lowered by calc).
int n, k, ans;
// Size of the component currently being processed; read by findC.
int glob;
// Per-vertex scratch for the current component (vertices are 1-based):
// par = parent in the DFS from the current root (0 for the root),
// sz = subtree size, mark = 1 once the vertex has been used as a centroid.
int par[N], sz[N], mark[N];
// Per-color scratch, indexed by color:
// cnt = occurrences inside the current component,
// vis = color already taken in the BFS.
int cnt[N], vis[N];
// c[v] = color of vertex v.
int c[N];
// adj = tree adjacency lists; g[col] = all vertices whose color is col.
vector<int> adj[N], g[N];

// Recursive DFS over the component containing v, ignoring vertices with
// mark == 1. p is v's parent (0 for the root).
// Records par[] and subtree sizes sz[]. When f == 1, it also counts how many
// vertices of each color the component contains (cnt[]).
void init(int v, int p=0, int f=0) {
	par[v] = p;
	sz[v] = 1;
	if (f == 1) ++cnt[c[v]];
	for (const int to : adj[v]) {
		if (to != p && mark[to] != 1) {
			init(to, v, f);
			sz[v] += sz[to];
		}
	}
}

// Finds a centroid of the component containing v. It uses the sizes from the
// preceding init() rooted at v and the total size stored in glob.
// A vertex qualifies when the part above it (glob - sz[v]) and every child
// subtree each have at most glob/2 vertices.
// Returns v itself if v qualifies; otherwise returns the largest result from
// its children (-1 when none qualifies). When two centroids exist, either may
// be returned.
int findC(int v, int p=0) {
	bool balanced = (glob - sz[v] <= glob / 2);
	int best = -1;
	for (const int to : adj[v]) {
		if (to == p || mark[to] == 1) continue;
		best = max(best, findC(to, v));
		if (sz[to] > glob / 2) balanced = false;
	}
	return balanced ? v : best;
}

// Resets the scratch state of the component containing v (unmarked vertices
// only), so the next calc() starts from zeros:
// cnt and vis for every color present, par and sz for every vertex.
void clen(int v, int p=0) {
	const int col = c[v];
	cnt[col] = vis[col] = 0;
	par[v] = sz[v] = 0;
	for (const int to : adj[v])
		if (to != p && mark[to] != 1)
			clen(to, v);
}

void calc(int v) {
	init(v);
	glob = sz[v];
	v = findC(v);
	init(v, 0, 1);
	
	int res = 0;
	queue<int> q;
	q.push(c[v]);
	vis[c[v]] = 1;
	while(size(q) > 0) {
		int col = q.front();
		q.pop();
		if(cnt[col] < size(g[col])) {
			res = k-1;
			break;
		}
		for(auto u : g[col]) {
			if(par[u]!=0 && vis[c[par[u]]]==0) {
				vis[c[par[u]]] = 1;
				q.push(c[par[u]]);
				res++;
			}
		}
	}
	clen(v);
	mark[v]=1;
	ans = min(ans, res);
	for(int u : adj[v]) {
		if(mark[u] == 0) {
			calc(u);
		}
	}
}

int main () {
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);

	cin >> n >> k;
	ans = k - 1;  // fallback (worst-case) value; calc() can only lower it

	// n-1 undirected tree edges.
	for (int e = 1; e < n; ++e) {
		int a, b;
		cin >> a >> b;
		adj[a].push_back(b);
		adj[b].push_back(a);
	}
	// Vertex colors, with vertices bucketed by color.
	for (int v = 1; v <= n; ++v) {
		cin >> c[v];
		g[c[v]].push_back(v);
	}

	calc(1);

	cout << ans << '\n';
	return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...