This submission was migrated from the previous version of oj.uz, which graded on a different machine, so resubmitting it may produce a different result.
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
#define all(x) (x).begin(), (x).end()
// Iterative (bottom-up) segment tree over `sz` leaves.
// Leaf i is stored at t[i + sz]; internal node i (1 <= i < sz) holds
// comb(t[2i], t[2i + 1]). `comb` is the aggregation (minimum) and `ID` is its
// identity element. query(l, r) aggregates over the CLOSED index range [l, r].
template<class T> struct ST {
    // Identity for comb: acts as +infinity as long as stored first components stay below 1e9.
    static constexpr T ID = {static_cast<int>(1e9), 0};
    // Aggregation; for std::array this is the lexicographic minimum.
    T comb(const T &a, const T &b) const { return std::min(a, b); }

    int sz = 0;        // number of leaves
    std::vector<T> t;  // 2 * sz nodes; t[0] is never read

    // Rebuilds the tree with `_sz` leaves, each initialised to `val` (default ID),
    // and recomputes every internal node from those leaves.
    void init(int _sz, T val = ID) {
        sz = _sz;
        t.assign(2 * sz, ID);
        for (int i = 0; i < sz; ++i)
            t[i + sz] = val;
        build();
    }

    // Rebuilds the tree with leaves copied from v.
    void init(const std::vector<T> &v) {
        sz = static_cast<int>(v.size());
        t.assign(2 * sz, ID);
        for (int i = 0; i < sz; ++i)
            t[i + sz] = v[i];
        build();
    }

    // Recomputes every internal node from its two children, bottom-up.
    void build() {
        for (int i = sz - 1; i > 0; --i)
            t[i] = comb(t[2 * i], t[2 * i + 1]);
    }

    // Sets leaf i (0 <= i < sz) to x and refreshes all of its ancestors.
    void upd(int i, T x) {
        for (t[i += sz] = x; i > 1; i >>= 1)
            t[i >> 1] = comb(t[i], t[i ^ 1]);
    }

    // Aggregate of leaves l..r inclusive; returns ID when l > r.
    // Left and right partial results are kept separate so argument order is preserved.
    T query(int l, int r) const {
        T ql = ID, qr = ID;
        for (l += sz, r += sz + 1; l < r; l >>= 1, r >>= 1) {
            if (l & 1) ql = comb(ql, t[l++]);
            if (r & 1) qr = comb(t[--r], qr);
        }
        return comb(ql, qr);
    }
};
// Disjoint-set union with union by size and path compression.
// e[x] < 0 marks x as a root whose set holds -e[x] elements; otherwise e[x] is x's parent.
struct DSU {
    vector<int> e;

    // Creates sz + 1 singleton sets with ids 0..sz.
    DSU(int sz) : e(sz + 1, -1) {}

    // Returns the representative of x's set, re-pointing every node on the
    // search path directly at that representative.
    int get(int x) {
        int root = x;
        while (e[root] >= 0)
            root = e[root];
        while (x != root) {
            const int parent = e[x];
            e[x] = root;
            x = parent;
        }
        return root;
    }

    bool same_set(int a, int b) { return get(a) == get(b); }

    int size(int x) { return -e[get(x)]; }

    // Merges the sets containing x and y; returns false if they were already one set.
    bool unite(int x, int y) {
        int ra = get(x);
        int rb = get(y);
        if (ra == rb)
            return false;
        // Sizes are stored negated, so a larger e[] value means a smaller set.
        if (e[ra] > e[rb])
            swap(ra, rb);
        e[ra] += e[rb];
        e[rb] = ra;
        return true;
    }
};
// Array capacity for 1-based vertex and color ids.
constexpr int N = 200007;
// Binary-lifting levels; lca() can lift by at most 2^L - 1.
constexpr int L = 20;

int n;  // number of vertices (read in main)
int k;  // number of colors (read in main)

vector<int> adj[N];  // tree adjacency; dfs() reorders it so adj[v][0] is the chain-continuing child
vector<int> col[N];  // col[x]: the vertices whose color is x
vector<int> topo;    // colors in dfs1 finishing order
vector<int> comp;    // colors collected by the current dfs2 call tree
set<int> active;     // HLD positions not yet expanded by process()
ST<array<int, 2>> st;  // per vertex i at index tin[i]: {dep[start[c[i]]], i}

int c[N];        // color of each vertex
int lift[N][L];  // lift[v][j]: 2^j-th ancestor of v (0 above the root)
int dep[N];      // depth, root 1 at depth 0
int sz[N];       // subtree sizes
int head[N];     // HLD chain head of each vertex
int pos[N];      // HLD position of each vertex
int start[N];    // start[x]: LCA of all vertices of color x
int who[N];      // who[p]: color of the vertex at HLD position p
int timer;       // counter shared by dfs() and dfs_hld()
int tin[N];      // Euler-tour entry time
int tout[N];     // last entry time inside the subtree
int reach[N];    // per chain head: deepest depth already expanded by query()
bool vis[N];     // visited flags for colors
bool done[N];    // not referenced by the code in this file
// First DFS from the root (vertex 1, parent sentinel 0).
// Records Euler-tour times tin/tout, depths, subtree sizes and binary-lifting
// ancestors, and moves the child with the largest subtree to adj[v][0] so that
// dfs_hld() continues the current chain through the heavy child.
// Recursion depth equals the height of the tree.
void dfs(int v = 1, int p = 0) {
    tin[v] = ++timer;
    sz[v] = 1;
    lift[v][0] = p;
    for (int i = 1; i < L; i++)
        lift[v][i] = lift[lift[v][i - 1]][i - 1];
    for (int &u : adj[v]) {
        if (u == p)
            continue;
        dep[u] = dep[v] + 1;
        dfs(u, v);
        sz[v] += sz[u];
        // Keep the largest child subtree at adj[v][0]. This must compare the
        // child's own size sz[u]. The running total sz[v] always exceeds it, so
        // comparing that would make the last child "heavy" and break HLD's
        // O(log n) chains-per-path bound.
        if (adj[v][0] == p || sz[u] > sz[adj[v][0]])
            swap(u, adj[v][0]);
    }
    tout[v] = timer;
}
// Second DFS: assigns consecutive positions pos[] in preorder and propagates
// chain heads. The child stored at adj[v][0] stays on v's chain; every other
// child starts a new chain. The caller seeds `timer` and head[root].
void dfs_hld(int v = 1, int p = 0) {
    pos[v] = timer;
    ++timer;
    for (const int child : adj[v]) {
        if (child == p)
            continue;
        if (child == adj[v][0])
            head[child] = head[v];  // continue the parent's chain
        else
            head[child] = child;    // open a new chain
        dfs_hld(child, v);
    }
}
// Lowest common ancestor of u and v using the binary-lifting table from dfs().
int lca(int u, int v) {
    // Make v the deeper endpoint, then lift it greedily to u's depth.
    if (dep[v] < dep[u])
        swap(u, v);
    for (int j = L; j-- > 0;)
        if (dep[v] - dep[u] >= (1 << j))
            v = lift[v][j];
    if (v == u)
        return u;
    // Lift both while their 2^j-th ancestors differ; they stop just below the LCA.
    for (int j = L; j-- > 0;) {
        if (lift[u][j] != lift[v][j]) {
            u = lift[u][j];
            v = lift[v][j];
        }
    }
    return lift[u][0];
}
void dfs1(int a);

// Expands every still-active HLD position in [a, b]. Each position is removed
// from `active` exactly once, and the color at that position is explored with
// dfs1 if it has not been visited. The set is re-queried after every step
// because dfs1 may erase further positions recursively.
void process(int a, int b) {
    for (;;) {
        const auto it = active.lower_bound(a);
        if (it == active.end() || *it > b)
            return;
        const int slot = *it;
        active.erase(it);
        const int color = who[slot];
        if (!vis[color])
            dfs1(color);
    }
}
// Expands the tree path from vertex b up to a, assumed to be an ancestor of b
// (dfs1 passes the color's LCA). b climbs chain by chain; each chain segment
// from its head down to b is handed to process(). reach[h] records the deepest
// depth already expanded on the chain headed by h, so shallower repeats are skipped.
void query(int a, int b) {
    while (head[b] != head[a]) {
        const int h = head[b];
        if (reach[h] < dep[b]) {
            process(pos[h], pos[b]);
            reach[h] = dep[b];
        }
        b = lift[h][0];
    }
    process(pos[a], pos[b]);
}
// First pass of the SCC computation on the color graph. Color a depends on
// every color found on a path from start[a] to one of its vertices. Those
// colors are explored through query()/process(), and a is appended to `topo`
// in post-order.
void dfs1(int a) {
    vis[a] = true;
    const vector<int> &members = col[a];
    for (size_t i = 0; i < members.size(); ++i)
        query(start[a], members[i]);
    topo.push_back(a);
}
// Second pass, on the reversed color graph. For each vertex v of color a, it
// repeatedly takes from the segment tree a vertex x in v's subtree whose color
// root start[c[x]] is no deeper than v. That root is then an ancestor of v, so
// the path from start[c[x]] to x passes through v. x is removed from the tree,
// and the walk recurses into c[x] if it is unvisited. Every color entered is
// appended to `comp`.
void dfs2(int a) {
    vis[a] = true;
    comp.push_back(a);
    for (const int v : col[a]) {
        for (;;) {
            const array<int, 2> best = st.query(tin[v], tout[v]);
            if (best[0] > dep[v])
                break;
            const int x = best[1];
            st.upd(tin[x], {(int) 1e9, 0});
            if (!vis[c[x]])
                dfs2(c[x]);
        }
    }
}
// Reads the tree and the vertex colors, then prints the minimum number of merges.
// Colors form a directed graph: a -> b when color b occurs on a path from start[a]
// to some vertex of color a. Strongly connected components are found in two
// passes (dfs1 for finishing order, dfs2 on the reversed graph). A component is a
// candidate when the vertices of its colors form one connected subtree; the
// answer is the smallest candidate's color count minus one.
// NOTE(review): assumes every color 1..k owns at least one vertex, because
// col[i][0] and col[comp[0]][0] are read unconditionally. Confirm against the
// problem constraints.
int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    cin >> n >> k;
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for (int i = 1; i <= n; i++) {
        cin >> c[i];
        col[c[i]].push_back(i);
    }

    // Euler tour, depths, binary lifting and heavy-child ordering.
    timer = 0;
    dfs();
    // start[i]: LCA of all vertices of color i.
    for (int i = 1; i <= k; i++) {
        start[i] = col[i][0];
        for (int x : col[i])
            start[i] = lca(start[i], x);
    }
    // Heavy-light positions 1..n; who[p] is the color at position p.
    timer = 1;
    head[1] = 1;
    dfs_hld();
    for (int i = 1; i <= n; i++)
        who[pos[i]] = c[i];
    for (int i = 1; i <= n; i++)
        active.insert(i);
    memset(reach, -1, sizeof(reach));

    // Pass 1: finishing order over the color graph.
    for (int i = 1; i <= k; i++)
        if (!vis[i])
            dfs1(i);
    reverse(all(topo));

    // Pass 2 setup: min over {depth of the color's root, vertex}, indexed by Euler time.
    st.init(n + 1);
    memset(vis, 0, sizeof(vis));
    for (int i = 1; i <= n; i++)
        st.upd(tin[i], {dep[start[c[i]]], i});

    // label[u]: id of the component that claimed u's color (0 = none yet).
    // This is a dedicated array rather than an overwrite of the HLD `head` table.
    vector<int> label(n + 1, 0);
    DSU dsu(n + 1);
    int ans = k - 1;
    for (int v : topo) {
        if (vis[v])
            continue;
        comp.clear();
        dfs2(v);
        // Label every vertex whose color belongs to this component.
        int cnt = 0;
        for (int x : comp)
            for (int u : col[x]) {
                label[u] = v;
                cnt++;
            }
        // Join tree edges whose endpoints both lie in this component.
        for (int x : comp)
            for (int s : col[x])
                for (int t : adj[s])
                    if (label[t] == v)
                        dsu.unite(s, t);
        // If these vertices are connected, every path between same-colored vertices
        // stays inside them, so no color outside the component is required.
        if (dsu.size(col[comp[0]][0]) == cnt)
            ans = min(ans, static_cast<int>(comp.size()) - 1);
    }
    cout << ans << '\n';
    return 0;
}
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |