#include <bits/stdc++.h>
using namespace std;
// Tree adjacency lists; main() sizes this n+1 and uses nodes 1..n.
vector<vector<int>> adj;
// parent[v][j] is the 2^j-th ancestor of v, or 0 when that ancestor does not exist.
vector<vector<int>> parent;
// depth[v] is the edge distance from the node dfs() was started at.
vector<int> depth;
// Next preorder index to hand out; incremented by dfs().
int timer{0};
// tin[v] is the preorder index at which dfs() first reached v.
vector<int> tin;
// Preorder traversal of the tree rooted at `node` (reached from `prev`, at depth `dis`).
// For every node v it records tin[v] (preorder index from `timer`), depth[v], and
// parent[v][0] (the node it was reached from; `prev` for the start node).
// Implemented iteratively with an explicit frame stack so that very deep (path-like)
// trees cannot overflow the call stack; visit order matches the recursive version.
void dfs(int node, int prev, int dis) {
    // One frame per node on the current root-to-node path.
    struct Frame { int node; int prev; size_t next; };
    vector<Frame> frames;

    // Record v's data on first arrival and open a frame for its children.
    auto enter = [&](int v, int p, int d) {
        tin[v] = timer++;
        depth[v] = d;
        parent[v][0] = p;
        frames.push_back({v, p, 0});
    };

    enter(node, prev, dis);
    while (!frames.empty()) {
        Frame &cur_frame = frames.back();
        if (cur_frame.next == adj[cur_frame.node].size()) {
            frames.pop_back();  // all children handled: this is the recursive "return"
            continue;
        }
        // Copy out before enter(): push_back may reallocate and invalidate cur_frame.
        const int cur = cur_frame.node;
        const int from = cur_frame.prev;
        const int child = adj[cur][cur_frame.next++];
        if (child != from) enter(child, cur, depth[cur] + 1);
    }
}
// Returns the k-th ancestor of x by taking one binary-lifting jump per set bit of k.
// Stops early once x becomes 0 (the "no ancestor" sentinel stored in `parent`).
int ancestor(int x, int k) {
    for (int bit = 0; k != 0 && x != 0; ++bit, k >>= 1) {
        if (k & 1) {
            x = parent[x][bit];
        }
    }
    return x;
}
// Lowest common ancestor of a and b using the binary-lifting table `parent`.
// First lifts the deeper node to the other's depth, then lifts both together while
// their 2^lvl-th ancestors still differ; the answer is then one step above either.
int lca(int a, int b) {
    if (depth[b] > depth[a]) swap(a, b);  // ensure a is the deeper endpoint
    a = ancestor(a, depth[a] - depth[b]);
    if (a == b) return a;
    for (int lvl = 19; lvl >= 0; --lvl) {
        if (parent[a][lvl] == parent[b][lvl]) continue;  // would overshoot (or reach) the LCA
        a = parent[a][lvl];
        b = parent[b][lvl];
    }
    return parent[a][0];
}
// componentIds[v] is the DSU representative of v's merged component (filled in main()).
vector<int> componentIds;
// Adjacency between component ids, built by dfsComponents() from the tree edges.
vector<vector<int>> adjComponents;
// Walks the tree from `node` (reached from `prev`) and, for every tree edge whose
// endpoints lie in different components, records the edge in both directions in
// adjComponents (indexed by componentIds). `prev == 0` means "no parent edge".
// Iterative with an explicit frame stack so deep trees cannot overflow the call
// stack; edges are appended in the same order as the recursive version.
void dfsComponents(int node, int prev) {
    struct Frame { int node; int prev; size_t next; };
    vector<Frame> frames;

    // Handle the edge (p, v) on first arrival at v, then open a frame for v's children.
    auto enter = [&](int v, int p) {
        if (p && componentIds[p] != componentIds[v]) {
            adjComponents[componentIds[p]].push_back(componentIds[v]);
            adjComponents[componentIds[v]].push_back(componentIds[p]);
        }
        frames.push_back({v, p, 0});
    };

    enter(node, prev);
    while (!frames.empty()) {
        Frame &cur_frame = frames.back();
        if (cur_frame.next == adj[cur_frame.node].size()) {
            frames.pop_back();
            continue;
        }
        // Copy out before enter(): push_back may reallocate and invalidate cur_frame.
        const int cur = cur_frame.node;
        const int from = cur_frame.prev;
        const int child = adj[cur][cur_frame.next++];
        if (child != from) enter(child, cur);
    }
}
struct DSU {
    // For a root r, e[r] == -(size of r's set); for any other x, e[x] is x's parent.
    vector<int> e;
    DSU(int N) : e(N, -1) {}
    // Returns the representative of x's set. Every node on the way is re-pointed
    // directly at the representative (path compression), done iteratively.
    int get(int x) {
        int root = x;
        while (e[root] >= 0) root = e[root];
        while (x != root) {
            const int up = e[x];
            e[x] = root;
            x = up;
        }
        return root;
    }
    bool same_set(int a, int b) { return get(a) == get(b); }
    int size(int x) { return -e[get(x)]; }
    // Union by size: the root of the larger set survives (x's root on a tie).
    // Returns false when x and y were already in the same set.
    bool unite(int x, int y) {
        int keep = get(x);
        int drop = get(y);
        if (keep == drop) return false;
        if (-e[keep] < -e[drop]) swap(keep, drop);
        e[keep] += e[drop];
        e[drop] = keep;
        return true;
    }
};
// Reads a tree on nodes 1..n and a state label in 1..k for every node. Every node on a
// tree path between two nodes of the same state is merged into one component. The tree
// is then contracted along those components, and the program prints ceil(L / 2), where
// L is the number of leaves (degree-1 components) of the contracted tree.
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, k; cin >> n >> k;
    adj = vector<vector<int>>(n+1);
    parent = vector<vector<int>>(n+1, vector<int>(20, 0));
    depth = vector<int>(n+1, 0);
    for (int i = 0; i < n-1; i++) {
        int u, v; cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    tin = vector<int>(n+1, 0);
    dfs(1, 0, 0);
    // Fill the binary-lifting table (`lvl`, not `k`, to avoid shadowing the state count).
    for (int lvl = 1; lvl < 20; lvl++) {
        for (int i = 1; i <= n; i++) {
            if (parent[i][lvl-1]) parent[i][lvl] = parent[parent[i][lvl-1]][lvl-1];
        }
    }
    DSU dsu(n+1);
    vector<vector<int>> st(k+1);
    for (int i = 1; i <= n; i++) {
        int state; cin >> state;
        st[state].push_back(i);
    }

    // top[v] == v means the edge (v, parent v) has not been merged yet. Otherwise top
    // points upward through already-merged edges, so findTop(v) is the highest node
    // reachable from v via merged edges. This lets every tree edge be merged at most
    // once instead of re-walking shared path segments for every same-state pair.
    vector<int> top(n+1);
    iota(top.begin(), top.end(), 0);
    auto findTop = [&](int v) -> int {
        int root = v;
        while (top[root] != root) root = top[root];
        while (top[v] != root) {  // path compression
            const int up = top[v];
            top[v] = root;
            v = up;
        }
        return root;
    };
    // Merge every edge on the path from v up to its ancestor anc, skipping merged runs.
    auto mergeUpTo = [&](int v, int anc) {
        for (v = findTop(v); depth[v] > depth[anc]; v = findTop(v)) {
            const int p = parent[v][0];
            dsu.unite(p, v);
            top[v] = p;
        }
    };

    // Consecutive pairs suffice: the union of their paths is exactly the minimal
    // subtree spanning all nodes of the state, regardless of the order of the nodes.
    for (int i = 1; i <= k; i++) {
        for (size_t j = 1; j < st[i].size(); j++) {
            const int x = st[i][j-1];
            const int y = st[i][j];
            const int anc = lca(x, y);
            mergeUpTo(x, anc);
            mergeUpTo(y, anc);
        }
    }

    componentIds = vector<int>(n+1);
    adjComponents = vector<vector<int>>(n+1);
    for (int i = 1; i <= n; i++) {
        componentIds[i] = dsu.get(i);
    }
    dfsComponents(1, 0);
    int ans = 0;
    for (int i = 1; i <= n; i++) {
        if (adjComponents[i].size() == 1) {
            ans++;
        }
    }
    ans = (ans + 1)/2;  // ceil(leaves / 2); a single component has 0 leaves -> 0
    cout << ans << '\n';
    return 0;
}
/*
 * Judge results table captured from the submission page (not part of the program):
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|
 * | Fetching results... |
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|
 * | Fetching results... |
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|
 * | Fetching results... |
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|
 * | Fetching results... |
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|
 * | Fetching results... |
 */