This submission was migrated from the previous version of oj.uz, which used a different machine for grading. It may produce a different result if resubmitted.
#include <bits/stdc++.h>
using namespace std;
// Disjoint-set union over the indices [0, n), using path compression and
// union by size. dad[v] == -1 marks v as the root of its set; sz[r] is
// only meaningful while r is a root.
struct DSU {
    vector<int> dad, sz;

    DSU(int n = 0) : dad(n, -1), sz(n, 1) {}

    // Returns the root of x's set and re-points every vertex on the walked
    // path directly at that root.
    int Find(int x) {
        int root = x;
        while (dad[root] != -1) {
            root = dad[root];
        }
        while (x != root) {
            const int next = dad[x];
            dad[x] = root;
            x = next;
        }
        return root;
    }

    // Merges the sets containing x and y. Returns false if they were
    // already the same set, true if a merge happened.
    bool Union(int x, int y) {
        int a = Find(x);
        int b = Find(y);
        if (a == b) {
            return false;
        }
        // Hang the smaller tree below the larger one.
        if (sz[a] < sz[b]) {
            swap(a, b);
        }
        dad[b] = a;
        sz[a] += sz[b];
        return true;
    }

    // Turns `node` back into a singleton root. The caller must reset every
    // index that took part in a union, or stale links remain.
    void Reset(int node) {
        sz[node] = 1;
        dad[node] = -1;
    }
};
// Centroid-decomposition solver.
//
// Input: n and k; then n-1 undirected edges (1-based endpoints); then one
// colour per vertex (1-based, expected in 1..k).
// For each centroid c of the current component, the colour set starting
// at col[c] is closed under: "for every vertex x of a taken colour, the
// colour of x's parent on the path toward c is taken too". The set is
// usable only if every taken colour has all of its vertices inside the
// component. Prints the minimum, over usable centroids, of
// (number of taken colours - 1), i.e. the number of successful DSU merges.
struct Solver {
    int n, k, ans;
    DSU dsu;                        // colour-indexed union-find, reset after each centroid
    vector<bool> vis;               // vis[v]: v has already been used as a centroid
    vector<vector<int>> adj, city;  // adjacency lists; city[c] = vertices of colour c in the current component
    vector<int> col, sz, nodes, dad, cnt;
    vector<int> walk_par;           // scratch parent array for the iterative GetSizes traversal

    void Main() {
        cin >> n >> k;
        adj.resize(n);
        for (int i = 0; i < n - 1; ++i) {
            int a, b;
            cin >> a >> b;
            --a, --b;
            adj[a].emplace_back(b);
            adj[b].emplace_back(a);
        }
        col.resize(n);
        cnt.resize(k);
        for (int i = 0; i < n; ++i) {
            cin >> col[i];
            --col[i];
            ++cnt[col[i]];
        }
        city.resize(k);
        dad.resize(n);
        walk_par.resize(n);
        vis.resize(n);
        dsu = DSU(k);
        sz.resize(n);
        ans = n;
        // With no vertices there is nothing to decompose (and adj[0] would be out of range).
        if (n > 0) {
            Centroid(0);
        }
        cout << ans << '\n';
    }

    // Appends every unvisited vertex reachable from `node` (not passing
    // through `par`) to `nodes`, and sets sz[v] to the size of v's subtree
    // when the component is rooted at `node`.
    // Iterative BFS: recursion would be O(n) deep on path-shaped trees.
    void GetSizes(int node, int par) {
        const size_t head = nodes.size();
        nodes.emplace_back(node);
        walk_par[node] = par;
        // `nodes` doubles as the BFS queue, so parents always precede children.
        for (size_t i = head; i < nodes.size(); ++i) {
            const int u = nodes[i];
            sz[u] = 1;
            for (int v : adj[u]) {
                if (v == walk_par[u] || vis[v]) continue;
                walk_par[v] = u;
                nodes.emplace_back(v);
            }
        }
        // Reverse BFS order finalises every child before it is added to its parent.
        for (size_t i = nodes.size() - 1; i > head; --i) {
            const int u = nodes[i];
            sz[walk_par[u]] += sz[u];
        }
    }

    // Returns a centroid of the unvisited component that contains `node`.
    // Leaves that component's vertices in `nodes`.
    int GetCentroid(int node) {
        nodes.clear();
        GetSizes(node, -1);
        const int total = (int)nodes.size();
        while (true) {
            bool moved = false;
            for (int x : adj[node]) {
                // sz[x] < sz[node] means x is a child of node in the rooted sizes.
                if (!vis[x] && sz[x] < sz[node] && 2 * sz[x] > total) {
                    node = x;
                    moved = true;
                    break;
                }
            }
            if (!moved) break;
        }
        return node;
    }

    // Sets dad[v] to v's parent when the unvisited component is rooted at
    // `node`; dad[node] = par.
    // Iterative DFS for the same stack-depth reason as GetSizes.
    void GetDads(int node, int par) {
        dad[node] = par;
        vector<int> stk(1, node);
        while (!stk.empty()) {
            const int u = stk.back();
            stk.pop_back();
            for (int v : adj[u]) {
                if (v == dad[u] || vis[v]) continue;
                dad[v] = u;
                stk.emplace_back(v);
            }
        }
    }

    // Processes the component containing `node` around its centroid, then
    // recurses into the sub-components left after removing that centroid.
    void Centroid(int node) {
        node = GetCentroid(node);
        vis[node] = true;
        GetDads(node, -1);
        for (int x : nodes) {
            city[col[x]].emplace_back(x);
        }

        // A colour is usable only if none of its vertices lie outside this component.
        auto whole_colour_here = [&](int c) { return (int)city[c].size() == cnt[c]; };

        vector<int> q;
        for (int x : city[col[node]]) {
            if (x != node) q.emplace_back(x);
        }

        int cnt_merges = 0;
        bool improves = whole_colour_here(col[node]);
        for (size_t it = 0; improves && it < q.size(); ++it) {
            const int x = q[it];
            if (!whole_colour_here(col[x])) {
                improves = false;  // this centroid cannot yield a valid answer
                break;
            }
            if (dsu.Union(col[x], col[dad[x]])) {
                // cnt_merges only grows, so stop once it can no longer beat ans.
                if (++cnt_merges >= ans) {
                    improves = false;
                    break;
                }
                for (int y : city[col[dad[x]]]) {
                    q.emplace_back(y);
                }
            }
        }
        if (improves) {
            ans = min(ans, cnt_merges);
        }

        // Undo per-centroid state; every touched colour belongs to a vertex in `nodes`.
        for (int x : nodes) {
            dsu.Reset(col[x]);
            city[col[x]].clear();
        }
        for (int x : adj[node]) {
            if (!vis[x]) {
                Centroid(x);
            }
        }
    }
};
// Entry point: enables fast stream I/O, then hands control to Solver::Main,
// which reads the input and prints the answer.
int main() {
    // Stop syncing with C stdio and untie cin from cout for faster bulk input.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    Solver solver{};
    solver.Main();
    return 0;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |