제출 #720262

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
720262 | PoPularPlusPlus | 수도 (JOI20_capital_city) | C++17
100 / 100
756 ms | 41980 KiB
#include <bits/stdc++.h>
using namespace std;
// Shorthand macros used throughout this file.
// ll / ull: 64-bit signed / unsigned integer type aliases.
#define ll long long
#define ull unsigned long long 
// pb(e): appends e to a container; written as `container.pb(e)`.
#define pb(e) push_back(e)
// sv(a): sorts a whole container; sa(a,n): sorts the first n elements of an array.
#define sv(a) sort(a.begin(),a.end())
#define sa(a,n) sort(a,a+n)
// mp(a,b): builds a pair; vf / vs: the pair's first / second member.
#define mp(a,b) make_pair(a,b)
#define vf first
#define vs second
// ar: shorthand for std::array (relies on `using namespace std`).
#define ar array
// all(x): expands to the begin/end iterator pair of x, for use as two arguments.
#define all(x) x.begin(),x.end()
// Template constants; none of them is referenced by the code visible in this file.
// inf: large int sentinel (0x3f3f3f3f); doubling it still fits in a 32-bit int.
const int inf = 0x3f3f3f3f;
// mod: the prime 1e9+7.
const int mod = 1000000007; 
// PI: written with more digits than a double can hold; the compiler rounds it.
const double PI=3.14159265358979323846264338327950288419716939937510582097494459230;

// 64-bit Mersenne Twister seeded from the steady clock, so its output differs per run.
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());

// Returns true when a % b is non-zero, i.e. a is not evenly divisible by b.
// b must be non-zero (the modulo is otherwise undefined).
bool remender(long long a , long long b){
	const long long rest = a % b;
	return rest != 0;
}

//freopen("problemname.in", "r", stdin);
//freopen("problemname.out", "w", stdout);

// Capacity of every per-node / per-colour array below.
const int N = 200002;
// n: number of nodes, k: second input value (ans starts at k-1); both read in solve().
int n , k;
// adj[u]: neighbours of node u (undirected edges, filled in solve()).
vector<int> adj[N];
// col[u]: colour of node u, read in solve() for u = 1..n.
int col[N];
// vis[u]: node u was already used as a centroid and is treated as removed.
// vis_col[c]: colour c belonged to an already-processed centroid.
bool vis[N],vis_col[N];
// v[c]: all nodes whose colour is c, in increasing node order.
vector<int> v[N];
// sub[u]: subtree size of u as computed by the latest pre() call.
int sub[N];
// Best (smallest) result found so far; printed by solve().
int ans;
// p[u]: parent of u in the current component rooted at its centroid (set by counting()).
int p[N];
// done[u]: per-node marker for the current centroid step:
//   0          = not part of the current component (or reset by dfs()),
//   -centroid  = inside the current component, colour not queued yet,
//   1          = node already pushed to the BFS queue.
int done[N];

// Fills sub[] with subtree sizes for the component containing `node`,
// walking away from `par` and never entering nodes marked in vis[].
void pre(int node , int par){
	sub[node] = 1;
	for(const int nxt : adj[node]){
		if(vis[nxt] || nxt == par){
			continue;
		}
		pre(nxt , node);
		sub[node] += sub[nxt];
	}
}

// Walks from `node` into the first unremoved child whose subtree (per sub[])
// is larger than `half`, repeating until no such child exists, and returns
// the node where the walk stops. Uses the sizes computed by pre().
int find(int node , int par , int half){
	while(true){
		bool moved = false;
		for(const int nxt : adj[node]){
			if(vis[nxt] || nxt == par || sub[nxt] <= half){
				continue;
			}
			// Descend into the heavy child; the node we leave becomes the parent.
			par = node;
			node = nxt;
			moved = true;
			break;
		}
		if(!moved){
			return node;
		}
	}
}

// Traverses the current component from `node` (skipping `par` and removed nodes).
// For every visited node it records its parent in p[] and stores the value
// `centroid` in done[]. Returns how many visited nodes have colour `c`.
int counting(int node , int par , int c , int centroid){
	p[node] = par;
	done[node] = centroid;
	int matches = (col[node] == c) ? 1 : 0;
	for(const int nxt : adj[node]){
		if(vis[nxt] || nxt == par){
			continue;
		}
		matches += counting(nxt , node , c , centroid);
	}
	return matches;
}

// Resets done[] to 0 for every node reachable from `node` without passing
// through `par` or a removed (vis[]) node.
void dfs(int node , int par){
	done[node] = 0;
	for(const int nxt : adj[node]){
		if(vis[nxt] || nxt == par){
			continue;
		}
		dfs(nxt , node);
	}
}

// Starting from the colour of `centroid`, repeatedly pulls in the colour of the
// parent (towards the centroid, via p[]) of every queued node, until the queued
// set is closed. Expects counting() to have just marked the whole current
// component with done[] == -centroid and to have filled p[].
//
// Returns the number of extra colours that had to be merged, or -1 when the
// attempt is invalid because a required colour
//   - lies (partly) outside the current component, or
//   - belongs to an already-processed centroid (vis_col[]).
//
// Every failure returns immediately. The previous code used `break` inside the
// inner loop, which only left that loop: the BFS kept running and a later
// `res++` could turn the -1 failure marker back into a valid-looking count.
static int merge_colours_from_centroid(int centroid){
	int merged = 0;
	queue<int> q;
	for(int i : v[col[centroid]]){
		done[i] = 1;
		q.push(i);
	}
	while(!q.empty()){
		const int node = q.front();
		q.pop();
		if(node == centroid)continue; // the root has no parent to pull in
		const int parent_colour = col[p[node]];
		// The state of a colour's first node stands for the whole colour.
		const int state = done[v[parent_colour][0]];
		if(state == 0){
			return -1; // colour lives outside the current component
		}
		if(state != -centroid){
			continue; // colour already queued
		}
		if(vis_col[parent_colour]){
			return -1; // colour of an earlier centroid: cannot be used here
		}
		merged++;
		for(int i : v[parent_colour]){
			if(done[i] != -centroid){
				return -1; // some node of this colour is outside the component
			}
			done[i] = 1;
			q.push(i);
		}
	}
	return merged;
}

// Centroid decomposition over the component containing `node`.
// For each centroid: if all nodes of its colour are inside the component,
// try to close that colour under "parent colour" merging and update `ans`
// with the number of merges. Afterwards the done[] markers are cleared, the
// centroid and its colour are marked as processed, and the remaining
// components are handled recursively.
void centroid_decomposition(int node){
	pre(node , node);
	const int centroid = find(node , node , sub[node]/2);
	const int c = col[centroid];
	// Marks the component with -centroid and counts nodes of colour c in it.
	const int cnt = counting(centroid,centroid,c ,-centroid);
	if(cnt == static_cast<int>(v[c].size())){
		const int res = merge_colours_from_centroid(centroid);
		if(res != -1){
			ans = min(ans , res);
		}
	}
	// Clear done[] for the whole component before the centroid is removed.
	dfs(centroid , centroid);
	vis_col[c] = 1;
	vis[centroid] = 1;
	for(int i : adj[centroid]){
		if(!vis[i]){
			centroid_decomposition(i);
		}
	}
}

// Reads n, k, the n-1 edges and the colour of each node from stdin, builds the
// adjacency lists and per-colour node lists, runs the centroid decomposition
// from node 1 and prints the resulting `ans`.
void solve(){
	cin >> n >> k;
	for(int e = 1; e < n; ++e){
		int a , b;
		cin >> a >> b;
		adj[a].push_back(b);
		adj[b].push_back(a);
	}
	for(int node = 1; node <= n; ++node){
		cin >> col[node];
		v[col[node]].push_back(node);
	}
	// Reset all markers before the decomposition starts.
	fill(begin(vis) , end(vis) , false);
	fill(begin(vis_col) , end(vis_col) , false);
	fill(begin(done) , end(done) , 0);
	// Start from k-1 and let the decomposition lower it.
	ans = k-1;
	centroid_decomposition(1);
	cout << ans << '\n';
}

// Entry point: switches iostreams to fast mode and solves one test case.
int main(){
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);
	// Single test case; for multi-test input, read t and call solve() t times.
	solve();
	return 0;
}

컴파일 시 표준 에러 (stderr) 메시지

capital_city.cpp: In function 'void centroid_decomposition(int)':
capital_city.cpp:81:9: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   81 |  if(cnt == v[c].size()){
      |     ~~~~^~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...