제출 #258053

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
258053 |  | AMO5 | Capital City (JOI20_capital_city) | C++17 | 11 / 100 | 3050 ms | 50012 KiB
#include <bits/stdc++.h>

using namespace std;

// Competitive-programming shorthands used throughout the file.
#define fi first
#define se second
#define eb emplace_back
#define mt make_tuple
#define all(x) (x).begin(), (x).end()
// The argument is parenthesised so that e.g. sz(cond ? a : b) or sz(*p) call
// .size() on the whole expression (the unparenthesised form bound .size() to
// only part of it).
#define sz(x) int((x).size())
#define MOD 1000000007

typedef long long ll;
typedef std::pair<int, int> ii;
typedef std::pair<ll, ll> pll;
typedef std::vector<int> vi;
typedef std::vector<ll> vll;
typedef long double ld;

const ll INF=LLONG_MAX;   // appears unused in this file
const int mxn=2e5+5;      // capacity of every per-town / per-city array below
bool DEBUG=0;             // appears unused in this file

// n: number of towns, k: number of cities (both read in main).
// city[i]: 0-based index of the city that town i belongs to.
// siz[i]: subtree size computed by dfs_sz() for the current component.
// h[i], par[i]: depth and parent of town i when the current component is
//               rooted at its centroid (assigned by dfs() / ctsolve()).
// ans: smallest number of cities found so far for a valid capital; main sets it to k.
int n,k,city[mxn],siz[mxn],h[mxn],par[mxn],ans;
// adj: tree adjacency lists; adj2[c]: the towns that belong to city c.
vi adj[mxn],adj2[mxn];
// nodes: towns of the current component; mcity: cities already merged into the
// candidate capital; usedcity: cities that have a town in the current component.
std::set<int>nodes,mcity,usedcity;
// citycnt[c]: number of towns of city c inside the current component.
int citycnt[mxn];
// vis[i]: town i has already been used as a centroid (removed from the tree).
bool vis[mxn];

// Disjoint-set forest used by ctsolve() to skip parts of the tree that an
// earlier climb already covered. merge() keeps the endpoint with the smaller
// h[] (depth below the current centroid) as representative, so rt(x) returns
// the shallowest town of x's set.
struct DSU{
	// Parent links of the forest. Named `link` rather than `par` so that it no
	// longer shadows the global tree-parent array par[] inside these methods.
	std::vector<int> link;
	// Builds n singleton sets {0}, ..., {n-1}.
	DSU(int n){
		reset(n);
	}
	// Re-initialises to n singleton sets. A negative n yields an empty forest
	// instead of converting to a huge size_t and failing inside resize().
	void reset(int n){
		if(n<0)n=0;
		link.resize(n);
		for(int i=0; i<n; i++){
			link[i]=i;
		}
	}
	// Makes u a singleton root again (undoes merges between centroids).
	void resetnode(int u){
		link[u]=u;
	}
	// Representative of u with full path compression. Iterative: the recursive
	// form needed one stack frame per link, i.e. O(n) depth on a long chain.
	int rt(int u){
		int root=u;
		while(link[root]!=root)root=link[root];
		// Second pass: point every town on the path directly at the root.
		while(link[u]!=root){
			int nxt=link[u];
			link[u]=root;
			u=nxt;
		}
		return root;
	}
	// Unites the sets of u and v, keeping the shallower root (global h[]) as
	// representative.
	void merge(int u, int v){
		u=rt(u); v=rt(v);
		if(u==v)return;
		if(h[u]>h[v])std::swap(u,v);
		link[v]=u;
	}
	// Alias of rt().
	int getrt(int u){
		return rt(u);
	}
};
DSU dsu(1);

// Computes siz[] for the component containing u, rooted at u (p is treated as
// u's parent, -1 for none), skipping removed towns (vis[]). Also adds each
// town to citycnt[] of its city and records those cities in `usedcity`.
// Iterative, so a path-shaped component cannot overflow the call stack: a
// pre-order is collected first, then sizes are summed in reverse so every
// child is complete before it is added to its parent.
void dfs_sz(int u, int p=-1){
	std::vector<std::pair<int,int>> order;   // (town, parent) in pre-order
	std::vector<std::pair<int,int>> todo;
	todo.emplace_back(u,p);
	while(!todo.empty()){
		const auto [x,px]=todo.back();
		todo.pop_back();
		order.emplace_back(x,px);
		siz[x]=1;
		citycnt[city[x]]++;
		usedcity.insert(city[x]);
		for(int y:adj[x]){
			if(y==px||vis[y])continue;
			todo.emplace_back(y,x);
		}
	}
	// order[0] is u itself; its parent p is left untouched, as before.
	for(std::size_t i=order.size(); i-- > 1;){
		siz[order[i].second]+=siz[order[i].first];
	}
}

// Starting from u (arrived at from p), walks towards any child whose subtree
// holds more than half of the sizz towns, and returns the first town where no
// such child exists. Relies on siz[] from dfs_sz() rooted at the start town.
// Written as a loop: the recursive form could be O(n) frames deep on a path.
int centroid(int u, int p, int sizz){
	bool descended=true;
	while(descended){
		descended=false;
		// The range-for binds adj[u] once, so updating u before `break` is safe.
		for(int v:adj[u]){
			if(v==p||vis[v])continue;
			if(siz[v]*2>sizz){
				p=u;
				u=v;
				descended=true;
				break;
			}
		}
	}
	return u;
}

// Roots the current component at u (the caller assigns par[u] and h[u]).
// Inserts every reachable non-removed town into `nodes` and sets par[] / h[]
// for all of them except u. Iterative, so a path-shaped component cannot
// overflow the call stack.
void dfs(int u, int p){
	std::vector<std::pair<int,int>> todo;   // (town, parent) still to expand
	todo.emplace_back(u,p);
	while(!todo.empty()){
		const auto [x,px]=todo.back();
		todo.pop_back();
		nodes.insert(x);
		for(int y:adj[x]){
			if(y==px||vis[y])continue;
			par[y]=x;
			h[y]=h[x]+1;
			todo.emplace_back(y,x);
		}
	}
}

// Tries the candidate in which the city of centroid `ct` is the capital, using
// only the towns of the component that `ct` currently heads (dfs() skips towns
// whose vis[] is set).
//   1. Root the component at `ct`: dfs() fills par[], h[] and `nodes`.
//   2. Keep a worklist `ncity` of cities that must be included. For every town
//      of a popped city, climb towards `ct`; the city of each town passed on
//      the way must be included as well. The DSU lets later climbs jump over
//      towns that were already climbed.
//   3. Fail (ok = 0) as soon as a required city has a town outside this
//      component, i.e. sz(adj2[c]) != citycnt[c].
// On success, `ans` is lowered to the number of cities collected in `mcity`.
// Finally the DSU links of `nodes`, the citycnt entries of `usedcity` and the
// three global sets are reset for the next centroid.
// Expects citycnt/usedcity to describe this component (solve() calls dfs_sz()
// on it right before calling this).
void ctsolve(int ct){
	// Root the current component at the centroid.
	par[ct]=-1;
	h[ct]=0;
	dfs(ct,-1);
	int City = city[ct];
	set<int>ncity;
	// Worklist: cities that must be included but have not been processed yet.
	ncity.emplace(City);
	bool ok=1;
	// The capital's own city must lie entirely inside this component.
	if(sz(adj2[City])!=citycnt[City]){
		ok=0; ncity.erase(City);
	}
	while(ncity.size()&&ok){
		// Take the smallest pending city and mark it as merged.
		int u = *ncity.begin();
		ncity.erase(u);
		mcity.insert(u);
		// Climb from every town of city u up to the centroid.
		for(int v:adj2[u]){
			while(v!=ct){
				// Jump to the shallowest town of v's already-climbed segment.
				v=dsu.rt(v);
				if(v==ct)break;
				// The town above v belongs to a city not seen yet...
				if(ncity.find(city[par[v]])==ncity.end()&&mcity.find(city[par[v]])==mcity.end()){
					// ...accept it only if all of its towns are in this component.
					if(sz(adj2[city[par[v]]])==citycnt[city[par[v]]])ncity.insert(city[par[v]]);
					else{
						ok=0;
						break;
					}
				}
				// Attach v's segment to its parent so later climbs skip it.
				dsu.merge(par[v],v);
				v=par[v];
			}
			if(!ok)break;
		}
	}
	// Undo every DSU link made while processing this centroid.
	for(int x:nodes){
		dsu.resetnode(x);
	}
	for(int x:usedcity)citycnt[x]=0;
	if(ok)ans = min(ans,sz(mcity));
	nodes=mcity=usedcity=set<int>();
	return;
}

// Centroid-decomposition driver: handles the component that contains `u`,
// removes its centroid, then recurses into every remaining piece.
void solve(int u=0){
	dfs_sz(u);                               // sizes + per-city counts of this component
	const int c = centroid(u, -1, siz[u]);
	ctsolve(c);                              // evaluate capital city[c] inside this component
	vis[c] = 1;                              // remove the centroid from the tree
	for (const int nb : adj[c]) {
		if (vis[nb]) continue;
		solve(nb);
	}
}

int main()
{
	ios_base::sync_with_stdio(0); cin.tie(0);
	//freopen("input.txt","r",stdin); freopen("output.txt","w",stdout);
	cin >> n >> k;
	ans = k;
	dsu.reset(n);
	// Read the n-1 tree edges (1-based in the input).
	for (int e = 0; e < n - 1; ++e) {
		int a, b;
		cin >> a >> b;
		--a;
		--b;
		adj[a].eb(b);
		adj[b].eb(a);
	}
	// Read each town's city (1-based in the input) and group towns by city.
	for (int town = 0; town < n; ++town) {
		int c;
		cin >> c;
		city[town] = --c;
		adj2[c].eb(town);
	}
	// A city made of a single town is already connected: no merges needed.
	for (int c = 0; c < k; ++c) {
		if (sz(adj2[c]) == 1) {
			cout << 0 << '\n';
			return 0;
		}
	}
	solve(0);
	cout << ans - 1 << '\n';
	return 0;
}
	
// READ & UNDERSTAND
// ll, int overflow, array bounds, memset(0)
// special cases (n=1?), n+1 (1-index)
// do something instead of nothing & stay organized
// WRITE STUFF DOWN
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...