Submission #1176396

# | Time | Username | Problem | Language | Result | Execution time | Memory
1176396 | | igofan | Mergers (JOI19_mergers) | C++20 | 56 / 100 | 3094 ms | 32584 KiB
#include <bits/stdc++.h>

using namespace std;

vector<vector<int>> adj, parent;  // adj: tree adjacency lists; parent[v][j]: 2^j-th ancestor of v (0 = none)
vector<int> depth;                // depth[v]: edge distance from the DFS start vertex

// Records depth[] and parent[][0] for every vertex reachable from `node` without
// going back through `prev`; `node` itself gets depth `dis` and parent `prev`.
// Uses an explicit stack instead of recursion: on a path-shaped tree the DFS is as
// deep as the vertex count, which can overflow the call stack for large n.
void dfs(int node, int prev, int dis) {
    vector<array<int, 3>> pending;  // (vertex, its parent, its depth) still to visit
    pending.push_back({node, prev, dis});
    while (!pending.empty()) {
        auto [v, p, d] = pending.back();
        pending.pop_back();
        depth[v] = d;
        parent[v][0] = p;
        for (int child : adj[v]) {
            if (child != p) pending.push_back({child, v, d + 1});
        }
    }
}
// Returns the k-th ancestor of x by decomposing k into powers of two and following
// parent[x][bit] for each set bit. Stops early once x becomes 0 (walked past the root).
int ancestor(int x, int k) {
    for (int bit = 0; k != 0 && x != 0; ++bit) {
        if (k & 1) x = parent[x][bit];
        k >>= 1;
    }
    return x;
}
// Lowest common ancestor of a and b: lift the deeper vertex to the other's depth,
// then raise both together by decreasing powers of two while their ancestors differ.
// The answer is the common parent of the two vertices reached.
int lca(int a, int b) {
    if (depth[b] > depth[a]) swap(a, b);
    a = ancestor(a, depth[a] - depth[b]);
    if (a == b) return a;
    for (int lvl = 19; lvl >= 0; --lvl) {
        const int upA = parent[a][lvl];
        const int upB = parent[b][lvl];
        if (upA == upB) continue;  // jumping this far would reach a common ancestor
        a = upA;
        b = upB;
    }
    return parent[a][0];
}

vector<int> componentIds;             // componentIds[v]: DSU representative of vertex v
vector<vector<int>> adjComponents;    // adjacency between components, indexed by representative

// For every tree edge (parent, child) met while walking the subtree of `node`
// (entered from `prev`, 0 = none), records the edge between their components in
// both directions when the endpoints belong to different components.
// Iterative preorder with children pushed in reverse, so edges are recorded in the
// same order a recursive DFS would use, but deep (path-shaped) trees cannot
// overflow the call stack.
void dfsComponents(int node, int prev) {
    vector<pair<int, int>> pending;  // (vertex, vertex it was reached from)
    pending.emplace_back(node, prev);
    while (!pending.empty()) {
        auto [v, from] = pending.back();
        pending.pop_back();
        if (from && componentIds[from] != componentIds[v]) {
            adjComponents[componentIds[from]].push_back(componentIds[v]);
            adjComponents[componentIds[v]].push_back(componentIds[from]);
        }
        for (auto it = adj[v].rbegin(); it != adj[v].rend(); ++it) {
            if (*it != from) pending.emplace_back(*it, v);
        }
    }
}

struct DSU {
    // e[x] < 0: x is a root and -e[x] is its set size; otherwise e[x] is x's parent link.
    vector<int> e;
    DSU(int N) : e(N, -1) {}
    // Representative of x's set; every vertex on the way is re-linked directly to it.
    int get(int x) {
        int root = x;
        while (e[root] >= 0) root = e[root];
        while (x != root) {
            const int next = e[x];
            e[x] = root;
            x = next;
        }
        return root;
    }
    bool same_set(int a, int b) { return get(a) == get(b); }
    int size(int x) { return -e[get(x)]; }
    // Merges the sets of x and y, attaching the smaller root under the larger
    // (on a tie x's root stays the root). Returns false if already in one set.
    bool unite(int x, int y) {
        int rootX = get(x);
        int rootY = get(y);
        if (rootX == rootY) return false;
        if (e[rootY] < e[rootX]) swap(rootX, rootY);
        e[rootX] += e[rootY];
        e[rootY] = rootX;
        return true;
    }
};

// Reads a tree of n cities and the state (1..k) of every city, and prints the
// minimum number of mergers required.
//
// Every edge lying on a path between two cities of the same state can never
// separate the country, so those edges are contracted with a DSU (states met on
// such a path are pulled into the same group). The contracted tree's number of
// leaves L then gives the answer (L + 1) / 2, which is 0 when everything collapses
// into a single component.
signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    int n, k; cin >> n >> k;
    adj = vector<vector<int>>(n+1);
    parent = vector<vector<int>>(n+1, vector<int>(20, 0));
    depth = vector<int>(n+1, 0);
    for (int i = 0; i < n-1; i++) {
        int u, v; cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    dfs(1, 0, 0);
    // Binary-lifting table. The level variable is `lvl` so it does not shadow the state count k.
    for (int lvl = 1; lvl < 20; lvl++) {
        for (int i = 1; i <= n; i++) {
            if (parent[i][lvl-1]) parent[i][lvl] = parent[parent[i][lvl-1]][lvl-1];
        }
    }

    DSU dsu(n+1);
    // chainTop[r] (r a DSU root) is the shallowest vertex of r's set. Sets only grow
    // by uniting a vertex with its parent, so each set is a connected subtree and its
    // shallowest vertex is the only member whose parent edge is not yet contracted.
    vector<int> chainTop(n+1);
    iota(chainTop.begin(), chainTop.end(), 0);
    auto topOf = [&](int v) { return chainTop[dsu.get(v)]; };

    vector<int> states(n+1);
    vector<bool> stateVisited(k+1, false), nodeVisited(n+1, false);
    vector<vector<int>> st(k+1);
    for (int i = 1; i <= n; i++) {
        cin >> states[i];
        st[states[i]].push_back(i);
    }

    vector<int> pending;  // cities of the current state group, in discovery order
    for (int i = 1; i <= k; i++) {
        if (stateVisited[i]) continue;
        pending.clear();
        // Marks `state` as part of the current group and enqueues its not-yet-seen cities.
        auto addState = [&](int state) {
            if (stateVisited[state]) return;
            stateVisited[state] = true;
            for (int city : st[state]) {
                if (nodeVisited[city]) continue;
                nodeVisited[city] = true;
                pending.push_back(city);
            }
        };
        addState(i);

        int anc = 0;
        size_t idx = 0;
        while (idx < pending.size()) {
            const size_t roundEnd = pending.size();
            for (size_t j = idx; j < roundEnd; j++) {
                anc = (anc == 0) ? pending[j] : lca(anc, pending[j]);
            }
            for (; idx < roundEnd; idx++) {
                // Climb from the city to anc, contracting each parent edge exactly once.
                // Already-contracted stretches are skipped via chainTop. Their vertices'
                // states were visited when they were merged, so nothing is missed, and
                // the total climbing work is bounded by n - 1 unions.
                int node = topOf(pending[idx]);
                while (depth[node] > depth[anc]) {
                    const int up = parent[node][0];
                    const int upTop = topOf(up);
                    addState(states[node]);
                    dsu.unite(node, up);
                    chainTop[dsu.get(node)] = upTop;
                    node = upTop;
                }
            }
            addState(states[anc]);
        }
    }

    componentIds = vector<int>(n+1);
    adjComponents = vector<vector<int>>(n+1);
    for (int i = 1; i <= n; i++) {
        componentIds[i] = dsu.get(i);
    }
    dfsComponents(1, 0);
    int leaves = 0;
    for (int i = 1; i <= n; i++) {
        if (adjComponents[i].size() == 1) leaves++;
    }
    cout << (leaves + 1) / 2 << '\n';
    return 0;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...