Submission #1176265

#TimeUsernameProblemLanguageResultExecution timeMemory
1176265igofanMergers (JOI19_mergers)C++20
0 / 100
50 ms19388 KiB
#include <bits/stdc++.h>

using namespace std;

vector<vector<int>> adj, parent;
vector<int> depth;
void dfs(int node, int prev, int dis) {
    depth[node] = dis;
    parent[node][0] = prev;
    for(auto child: adj[node]) if (child != prev) dfs(child, node, dis+1);
}
int ancestor(int x, int k) {
    int cnt = 0;
    while(k && x) {
        if (k&1) x = parent[x][cnt];
        k = k>>1; cnt++;
    }
    return x;
}
int lca(int a, int b) {
    if (depth[a] < depth[b]) swap(a, b);
    a = ancestor(a, depth[a]-depth[b]);
    if (a==b) return a;
    for(int k=19; k>=0; k--) {
        int aa = parent[a][k];
        int bb = parent[b][k];
        if (aa != bb) {
            a = aa; b = bb;
        }
    }
    return parent[a][0];
}

signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    int n, k; cin >> n >> k;
    adj = vector<vector<int>>(n+1);
    parent = vector<vector<int>>(n+1, vector<int>(20, 0));
    depth = vector<int>(n+1, 0);
    for(int i=0; i<n-1; i++) {
        int u, v; cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    dfs(1, 0, 0);
    for(int k=1; k<20; k++) {
        for(int i=1; i<=n; i++) {
            if (parent[i][k-1]) parent[i][k] = parent[parent[i][k-1]][k-1];
        }
    }

    vector<int> states(n+1);
    vector<bool> stateVisited(k+1, false), nodeVisited(n+1, false);
    vector<vector<int>> st(k+1);
    for(int i=1; i<=n; i++) {
        cin >> states[i];
        st[states[i]].push_back(i);
    }
    int ans = 0;
    for(int i=1; i<=k; i++) {
        if (stateVisited[i]) continue;
        stateVisited[i] = true;
        vector<int> queue;
        for(auto node: st[i]) {
            if (nodeVisited[node]) continue;
            nodeVisited[node] = true;
            queue.push_back(node);
        }
        int anc = 0, idx = 0;
        while(idx < queue.size()) {
            int queueSize = queue.size();
            for(int j=idx; j<queueSize; j++) {
                int node = queue[j];
                if (anc == 0) anc = node;
                else anc = lca(anc, node);
            }
            while(idx < queueSize) {
                int node = queue[idx++];
                while(node != anc) {
                    if (!stateVisited[states[node]]) {
                        stateVisited[states[node]] = true;
                        for(auto child: st[states[node]]) {
                            if (nodeVisited[child]) continue;
                            nodeVisited[child] = true;
                            queue.push_back(child);
                        }
                    }
                    node = parent[node][0];
                }
            }
            if (!stateVisited[states[anc]]) {
                stateVisited[states[anc]] = true;
                for(auto child: st[states[anc]]) {
                    if (nodeVisited[child]) continue;
                    nodeVisited[child] = true;
                    queue.push_back(child);
                }
            }
        }
        for(auto node: queue) {
            if (adj[node].size() == 1) {
                ans++;
                break;
            }
        }
    }
    ans /= 2;
    cout << ans << '\n';
    return 0;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...