#include <bits/stdc++.h>
// ---- Loop shorthands --------------------------------------------------------
// FOR/F0R walk i upward over [x, n); ROF/R0F walk i downward over [x, n).
#define FOR(i,x,n) for(int i=x; i<n; i++)
#define F0R(i,n) FOR(i,0,n)
#define ROF(i,x,n) for(int i=n-1; i>=x; i--)
#define R0F(i,n) ROF(i,0,n)
// ---- Misc shorthands --------------------------------------------------------
#define WTF cout << "WTF" << endl                   // debug marker on stdout
#define IOS ios::sync_with_stdio(false); cin.tie(0) // untie/unsync iostreams
#define F first
#define S second
#define pb push_back
#define ALL(x) x.begin(), x.end()
#define RALL(x) x.rbegin(), x.rend()
using namespace std;

// ---- Type aliases -----------------------------------------------------------
using LL = long long;
using PII = pair<int, int>;
using PLL = pair<LL, LL>;
using VI = vector<int>;
using VLL = vector<LL>;
using VPII = vector<PII>;
using VPLL = vector<PLL>;

// ---- Constants --------------------------------------------------------------
constexpr int N = 2e5 + 7;     // static array capacity (presumably max n plus slack)
constexpr int ALPHA = 27;      // not referenced by the visible code
constexpr int INF = 1e9 + 7;   // "no answer" sentinel returned by solve()
constexpr int MOD = 1e9 + 7;   // not referenced by the visible code
constexpr int LOG = 22;        // not referenced by the visible code

// ---- Global state -----------------------------------------------------------
int n;                // number of towns
int k;                // number of cities (colours)
int vcnt;             // how many cities are still marked valid
int eset[N][2];       // tree edges, 0-based endpoints
int ns[N];            // ns[v] = city of town v, 0-based
int comp[N];          // component id of each town in the current forest
int par[N];           // parent pointers, rooted at the current centroid
int sz[N];            // subtree sizes for the current component
bool valid[N];        // valid[city]: city is still a candidate
bool vis[N];          // shared scratch flag (per town in buildGraph, per city in solve)
VI adj[N];            // adjacency of the current forest
VI rs;                // one representative town per component
VI group[N];          // group[city] = towns belonging to that city
// Reads the instance from standard input:
//   n k
//   n-1 lines of tree edges (1-based town ids)
//   n city ids, one per town (1-based)
// Everything is converted to 0-based, towns are bucketed by city into
// group[], and every city starts out valid (vcnt = k).
void init() {
    // Read from the real std::cin. The previous version declared a local
    // `ifstream cin` opened on "input.txt", which shadowed std::cin; under the
    // grader (input on stdin) nothing was read, n/k stayed 0 and INF was printed.
    std::cin >> n >> k;
    F0R(i, n - 1) {
        std::cin >> eset[i][0] >> eset[i][1];
        eset[i][0]--;
        eset[i][1]--;
    }
    F0R(i, n) {
        std::cin >> ns[i];
        ns[i]--;
        group[ns[i]].pb(i);
    }
    std::fill(valid, valid + k, true);
    vcnt = k;
}
// Flood-fills the component of `now` in the current forest (adj[]):
// every reachable, not-yet-visited town is flagged in vis[] and labelled
// with component id `c`. Uses an explicit work-list instead of recursion.
void preD(int now, int c) {
    std::vector<int> pending{now};
    vis[now] = 1;
    comp[now] = c;
    while (!pending.empty()) {
        const int cur = pending.back();
        pending.pop_back();
        for (int nxt : adj[cur]) {
            if (vis[nxt]) continue;
            vis[nxt] = 1;
            comp[nxt] = c;
            pending.push_back(nxt);
        }
    }
}
// Rebuilds the working forest for the next round and prunes cities:
//  * keeps only tree edges whose two endpoint towns belong to valid cities;
//  * labels connected components (comp[]) and stores one representative
//    town per component in rs;
//  * invalidates every valid city whose towns are spread over more than one
//    component, since they can no longer be joined inside a single piece;
//  * invalidates a valid city that owns no town at all. Otherwise it would
//    never be chosen as a centroid's city, vcnt would never reach 0, and the
//    main loop would not terminate. This should not happen if every city has
//    at least one town.
// Always returns true; the caller uses it inside its loop condition.
bool buildGraph() {
    F0R(i, n) adj[i].clear();
    rs.clear();
    F0R(i, n - 1) {
        const int u = eset[i][0];
        const int v = eset[i][1];
        if (valid[ns[u]] && valid[ns[v]]) {
            adj[u].pb(v);
            adj[v].pb(u);
        }
    }

    memset(vis, 0, sizeof vis);
    int c = 0;
    F0R(i, n) if (!vis[i]) {
        rs.pb(i);
        preD(i, c);
        c++;
    }

    F0R(i, k) {
        if (!valid[i]) continue;
        const auto &towns = group[i];
        // All towns share a component iff each matches the first one.
        // size_t index: avoids the signed/unsigned comparison that
        // FOR(j, 1, group[i].size()) produced (-Wsign-compare).
        bool drop = towns.empty();
        for (size_t j = 1; j < towns.size() && !drop; j++)
            drop = comp[towns[j]] != comp[towns[0]];
        if (drop) {
            valid[i] = 0;
            vcnt--;
        }
    }
    return true;
}
// Fills sz[] with subtree sizes of the tree hanging from `now`
// (`p` is its parent, -1 at the root) and returns sz[now].
int getSz(int now, int p) {
    int total = 1;
    for (int child : adj[now]) {
        if (child == p) continue;
        total += getSz(child, now);
    }
    sz[now] = total;
    return total;
}
// Starting at `now` (coming from `p`), repeatedly steps into the first child
// whose subtree size (from sz[]) exceeds s/2. Returns the node where no such
// child exists, i.e. the centroid of a tree of `s` nodes.
int getCent(int now, int p, int s) {
    const int half = s / 2;
    bool stepped = true;
    while (stepped) {
        stepped = false;
        for (int nxt : adj[now]) {
            if (nxt == p || sz[nxt] <= half) continue;
            p = now;
            now = nxt;
            stepped = true;
            break;
        }
    }
    return now;
}
// Roots the tree reachable from `now` at `now`: sets par[now] = p and
// par[w] for every other reachable node to its neighbour on the path
// towards `now`. Uses an explicit stack instead of recursion.
void getPar(int now, int p) {
    par[now] = p;
    std::vector<int> frontier{now};
    while (!frontier.empty()) {
        const int u = frontier.back();
        frontier.pop_back();
        for (int w : adj[u]) {
            if (w == par[u]) continue;
            par[w] = u;
            frontier.push_back(w);
        }
    }
}
// Evaluates the centroid of the component containing `root`.
//
// The parent pointers are rooted at the centroid c. Starting from c's city,
// the function repeatedly adds the city of the parent of every town already
// included; this is the set of cities that must be merged so that c's city
// becomes connected. It returns the number of cities added, or INF when:
//  * c's city is already invalid, or
//  * the closure needs an invalid city.
//
// vis[] is used as a per-city "already included" flag; the caller clears it.
//
// Side effect: whenever c's city was valid on entry, it is invalidated here,
// including the INF case. A closure that contains an invalid city costs at
// least as much as that city, which was already answered or dominated. So
// c's city cannot improve the answer, and removing it guarantees the
// component is cut at its centroid in the next round.
int solve(int root) {
    getSz(root, -1);
    const int c = getCent(root, -1, sz[root]);
    getPar(c, -1);

    const int cap = ns[c];
    if (!valid[cap]) return INF;

    int ret = 0;
    bool blocked = false;
    queue<int> keep;
    assert(!vis[cap]);
    vis[cap] = 1;
    keep.push(cap);
    while (!keep.empty() && !blocked) {
        const int now = keep.front();
        keep.pop();
        for (int on : group[now]) {
            const int p = par[on];
            if (p == -1) continue;              // `on` is the centroid itself
            const int pc = ns[p];
            if (!valid[pc]) {                   // closure needs a discarded city
                blocked = true;
                break;
            }
            if (!vis[pc]) {
                vis[pc] = 1;
                keep.push(pc);
                ret++;
            }
        }
    }

    // Previously the blocked case returned early and left cap valid.
    valid[cap] = 0;
    vcnt--;
    return blocked ? INF : ret;
}
// Reads the input, then runs rounds until every city has been evaluated or
// discarded. Each round rebuilds the forest of valid cities and evaluates the
// centroid of every component; the smallest value returned by solve() is
// printed.
int main() {
    IOS;
    init();
    int ans = INF;
    while (buildGraph() && vcnt) {
        memset(vis, 0, sizeof vis);   // vis[] becomes the per-city flag for solve()
        for (int on : rs)
            ans = min(ans, solve(on));
    }
    // The former debug round counter written to stderr has been removed.
    cout << ans << '\n';
}
Compilation message
capital_city.cpp: In function 'bool buildGraph()':
capital_city.cpp:3:35: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
3 | #define FOR(i,x,n) for(int i=x; i<n; i++)
......
92 | FOR(j, 1, group[i].size()) {
| ~~~~~~~~~~~~~~~~~~~~~
capital_city.cpp:92:9: note: in expansion of macro 'FOR'
92 | FOR(j, 1, group[i].size()) {
| ^~~
| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| 1 | Incorrect | 5 ms | 9812 KB | Output isn't correct |
| 2 | Halted | 0 ms | 0 KB | - |
| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| 1 | Incorrect | 5 ms | 9812 KB | Output isn't correct |
| 2 | Halted | 0 ms | 0 KB | - |
| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| 1 | Incorrect | 5 ms | 9812 KB | Output isn't correct |
| 2 | Halted | 0 ms | 0 KB | - |
| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| 1 | Incorrect | 5 ms | 9812 KB | Output isn't correct |
| 2 | Halted | 0 ms | 0 KB | - |