This submission was migrated from the previous version of oj.uz, which used a different machine for grading. It may get a different result if resubmitted.
#include <bits/stdc++.h>
using namespace std;
// Every `int` token after this line is really `long long` (file-wide keyword rewrite).
#define int long long
// Initial value of `ans`; larger than any candidate answer.
#define INF (int)1e18
// NOTE(review): these two macros rewrite ANY identifier spelled `f` or `s`
// below this point; nothing in the visible code relies on them.
#define f first
#define s second
// Time-seeded 64-bit generator; not referenced anywhere in the visible code.
mt19937_64 RNG(chrono::steady_clock::now().time_since_epoch().count());
// n: number of vertices; k: read by Solve but otherwise unused here
// (presumably the number of colours — TODO confirm against the statement).
int n, k;
// Capacity of the vertex-/colour-indexed arrays (vertices are read as 1..n).
const int N = 2e5 + 69;
// adj: adjacency lists; comp: vertices of the component last walked by dfs;
// upd: colours pushed by dfs2 while processing the current centroid.
vector <int> adj[N], comp, upd;
// c: colour of each vertex; sz: subtree sizes computed by dfs;
// tot[col]: how many vertices have colour col (filled in Solve).
int c[N], sz[N], tot[N];
// del: vertex has already been used as a centroid and is removed;
// need: colour-indexed "seen by dfs2" flags; has: dfs2 subtree contains the centroid's colour.
bool del[N], need[N], has[N];
// Centroid currently being processed (compared against in dfs2).
int cen;
// Smallest candidate found so far; printed by Solve.
int ans = INF;
// Count of vertices with the centroid's colour visited by dfs2.
int val;
// Walks the live (non-deleted) component reachable from u without stepping
// back into `par`. Appends every vertex to `comp` in pre-order and stores in
// sz[u] the number of vertices in u's subtree of this walk.
void dfs(int u, int par = -1){
    comp.push_back(u);
    sz[u] = 1;
    for (int nxt : adj[u]){
        if (nxt == par || del[nxt]) continue;  // skip the parent and removed centroids
        dfs(nxt, u);
        sz[u] += sz[nxt];
    }
}
// Walks the live subtree of u (rooted at the centroid `cen`).
// has[u] becomes true when u's subtree holds a vertex with the centroid's colour,
// and `val` counts the vertices of that colour seen. On the way back up, every
// vertex with has[u] set contributes its colour to `upd` and flags it in `need`.
void dfs2(int u, int par = -1){
    if (c[u] == c[cen]){
        has[u] = true;
        ++val;
    }
    for (int nxt : adj[u]){
        if (nxt == par || del[nxt]) continue;
        dfs2(nxt, u);
        if (has[nxt]) has[u] = true;
    }
    if (!has[u]) return;
    upd.push_back(c[u]);
    need[c[u]] = true;
}
// Returns a centroid of the live component containing x: a vertex whose
// removal leaves no remaining part with more than half of the component.
// Side effects: refills `comp` (component in dfs pre-order) and sz[] via dfs(x).
int find(int x){
    comp.clear();
    dfs(x);
    const int total = static_cast<int>(comp.size());
    for (int u : comp){
        // The part on u's parent side holds everything outside u's subtree
        // (0 for the dfs root). The old `total - 1 - sz[u]` under-counted it by one.
        int largest = total - sz[u];
        for (int v : adj[u]){
            // Live neighbours with a smaller subtree are exactly u's dfs children.
            if (!del[v] && sz[v] < sz[u]) largest = max(largest, sz[v]);
        }
        // Signed comparison; avoids the -Wsign-compare mix with size_t.
        if (largest <= total / 2) return u;
    }
    assert(false);  // every tree has a centroid, so the loop always returns
    return x;       // keeps the function well-defined when NDEBUG disables assert
}
void cd(int x){
x = find(x);
upd.clear();
val = 0;
cen = x;
for (auto u : comp) has[u] = false;
dfs2(x);
int cnt = 0;
for (auto u : upd){
if (need[u]){
need[u] = false;
cnt++;
}
}
if (val == tot[c[x]])
ans = min(ans, cnt - 1);
del[x] = true;
for (int v : adj[x]){
if (!del[v]){
cd(v);
}
}
}
// Reads the tree (n vertices, n - 1 edges) and one colour per vertex, runs the
// centroid decomposition from vertex 1 and prints the best value found in `ans`.
void Solve()
{
    cin >> n >> k;
    int remaining = n - 1;
    while (remaining-- > 0) {
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    // Colours of vertices 1..n, plus a per-colour vertex count.
    for (int v = 1; v <= n; ++v) {
        cin >> c[v];
        ++tot[c[v]];
    }
    cd(1);
    cout << ans << "\n";
}
// Entry point: fast I/O setup, a single test case, and elapsed wall time on stderr.
int32_t main()
{
    const auto started = chrono::high_resolution_clock::now();
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int testCount = 1;
    // cin >> testCount;
    for (int tc = 1; tc <= testCount; ++tc)
    {
        //cout << "Case #" << tc << ": ";
        Solve();
    }

    const auto finished = chrono::high_resolution_clock::now();
    const auto nanos = chrono::duration_cast<chrono::nanoseconds>(finished - started).count();
    cerr << "Time measured: " << nanos * 1e-9 << " seconds.\n";
    return 0;
}
Compilation message (stderr)
capital_city.cpp: In function 'long long int find(long long int)':
capital_city.cpp:57:16: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
57 | if (mx <= comp.size() / 2) return u;
| ~~~^~~~~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |