// In the name of God
#include<bits/stdc++.h>
using namespace std;
#define pb push_back
#define fi first
#define se second
#define mp make_pair
typedef long long ll;
const int N = 2e5 + 5, lg = 20;
int n, k, a[N];
vector<int> adj[N], vec[N], G[N];
int up[N][lg], tin[N], tout[N], T, h[N];
int f[N], par[N];
void dfs(int v, int p) {
up[v][0] = p;
par[v] = p;
for (int i = 1; i < lg; i++) {
up[v][i] = up[up[v][i - 1]][i - 1];
}
tin[v] = ++T;
for (auto u : adj[v]) {
if (u != p) {
h[u] = h[v] + 1;
dfs(u, v);
}
}
tout[v] = ++T;
}
bool anc(int v, int u) {
return tin[v] <= tin[u] && tout[v] >= tout[u];
}
int lca(int v, int u) {
if (tin[v] > tin[u]) swap(v, u);
if (anc(v, u)) return v;
for (int i = lg - 1; i >= 0; i--) {
if (!anc(up[u][i], v)) {
u = up[u][i];
}
}
return up[u][0];
}
vector<int> V;
int out[N];
void dfs2(int v, int p) {
vector<int> dat;
while (!V.empty() && h[f[V.back()]] >= h[f[v]]) {
dat.pb(V.back());
V.pop_back();
}
V.pb(v);
int l = -1, r = V.size() - 1;
while (r - l > 1) {
int mid = (l + r) >> 1;
if (h[V[mid]] >= h[f[v]])
r = mid;
else l = mid;
}
if (a[v] != a[V[r]]) {
out[a[v]]++;
G[a[V[r]]].pb(a[v]);
}
for (auto u : adj[v]) {
if (u != p)
dfs2(u, v);
}
V.pop_back();
while (!dat.empty()) {
V.pb(dat.back());
dat.pop_back();
}
}
int F[N];
bool vis[N];
void dfs3(int v) {
vis[v] = true;
for (auto u : G[v]) {
if (vis[u]) continue;
if (h[F[u]] > h[F[v]])
F[u] = F[v];
dfs3(u);
}
}
int col[N];
void solve() {
cin >> n >> k;
for (int i = 0; i < n - 1; i++) {
int v, u; cin >> v >> u;
adj[v].pb(u), adj[u].pb(v);
}
for (int i = 1; i <= n; i++) {
cin >> a[i];
vec[a[i]].pb(i);
}
if (n > 20) return;
dfs(1, 1);
for (int i = 1; i <= k; i++) {
int u = vec[i].back();
for (auto v : vec[i]) {
u = lca(u, v);
}
for (auto v : vec[i])
f[v] = u;
F[i] = u;
}
//dfs2(1, 0);
for (int v = 1; v <= n; v++) {
int u = par[v], mx = v;
while (u != f[v] && v != f[v]) {
if (h[f[u]] < h[f[mx]]) {
mx = u;
}
u = par[u];
}
u = mx;
if (a[v] != a[u]) {
G[a[u]].pb(a[v]);
out[a[v]]++;
}
}
vector<pair<int, int> > ver;
for (int i = 1; i <= k; i++) {
if (out[i] == 0) {
ver.pb(mp(h[vec[i].back()], f[vec[i].back()]));
dfs3(i);
}
}
sort(ver.begin(), ver.end(), greater<pair<int, int> >());
int ans = 1e9;
for (auto [tmp, v] : ver) {
if (col[a[v]])
continue;
queue<int> q;
q.push(a[v]);
col[a[v]] = a[v];
bool ok = true;
int cnt = 0;
while (!q.empty()) {
int c = q.front();
q.pop();
if (col[c] != a[v]) {
ok = false;
break;
}
for (auto w : vec[c]) {
if (w == v) continue;
int u = par[w];
if (col[a[u]] == 0) {
col[a[u]] = a[v];
cnt++;
q.push(a[u]);
}
}
}
if (ok)
ans = min(ans, cnt);
}
cout << ans << '\n';
}
int32_t main() {
ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
int tc = 1; // cin >> tc;
while (tc--) {
solve();
}
return 0;
}
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
6 ms |
22108 KB |
Output is correct |
2 |
Correct |
6 ms |
22108 KB |
Output is correct |
3 |
Correct |
6 ms |
22104 KB |
Output is correct |
4 |
Incorrect |
6 ms |
22108 KB |
Output isn't correct |
5 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
6 ms |
22108 KB |
Output is correct |
2 |
Correct |
6 ms |
22108 KB |
Output is correct |
3 |
Correct |
6 ms |
22104 KB |
Output is correct |
4 |
Incorrect |
6 ms |
22108 KB |
Output isn't correct |
5 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
90 ms |
27988 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
6 ms |
22108 KB |
Output is correct |
2 |
Correct |
6 ms |
22108 KB |
Output is correct |
3 |
Correct |
6 ms |
22104 KB |
Output is correct |
4 |
Incorrect |
6 ms |
22108 KB |
Output isn't correct |
5 |
Halted |
0 ms |
0 KB |
- |