Submission #1176409

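| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1176409 | | igofan | Mergers (JOI19_mergers) | C++20 | 56 / 100 | 3093 ms | 32072 KiB |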
#include <bits/stdc++.h>

using namespace std;

vector<vector<int>> adj, parent;
vector<int> depth;

// Roots the subtree reached from `node` (entered from `prev`, 0 = no parent) and fills
// depth[v] and parent[v][0] for every vertex reached without passing back through `prev`.
// `dis` is the depth assigned to `node` itself; each child is one deeper than its parent.
// Implemented with an explicit stack instead of recursion: on a path-shaped tree the
// recursive version needs one call frame per vertex, which can overflow the call stack
// for large inputs.
void dfs(int node, int prev, int dis) {
    // Pending work items: {vertex, its parent, its depth}.
    std::vector<std::array<int, 3>> todo;
    todo.push_back({node, prev, dis});
    while (!todo.empty()) {
        auto [v, p, d] = todo.back();
        todo.pop_back();
        depth[v] = d;
        parent[v][0] = p;
        for (int child : adj[v]) {
            if (child != p) todo.push_back({child, v, d + 1});
        }
    }
}
// Climbs k levels above x using the binary-lifting table: for every set bit `bit` of k,
// jump parent[x][bit]. The climb stops early if x becomes 0, in which case 0 is returned.
int ancestor(int x, int k) {
    for (int bit = 0; k != 0 && x != 0; ++bit, k >>= 1) {
        if (k & 1) {
            x = parent[x][bit];
        }
    }
    return x;
}
// Lowest common ancestor of a and b via binary lifting over parent[][] and depth[].
// First lifts the deeper vertex to the other's depth; then, from the largest jump
// (2^19) down, lifts both while their ancestors still differ. The answer is the
// immediate parent of where they stop.
int lca(int a, int b) {
    int deep = a;
    int shallow = b;
    if (depth[deep] < depth[shallow]) swap(deep, shallow);
    deep = ancestor(deep, depth[deep] - depth[shallow]);
    if (deep == shallow) return deep;
    for (int lvl = 19; lvl >= 0; --lvl) {
        if (parent[deep][lvl] == parent[shallow][lvl]) continue;
        deep = parent[deep][lvl];
        shallow = parent[shallow][lvl];
    }
    return parent[deep][0];
}

vector<int> componentIds;
vector<vector<int>> adjComponents;

// Traverses the tree from `node` (entered from `prev`, 0 = no parent). For every tree edge
// whose endpoints have different componentIds, it appends the component-level edge in both
// directions to adjComponents.
// Uses an explicit frame stack instead of recursion so deep (path-like) trees cannot
// overflow the call stack. Frames are visited in the same preorder, and edges appended in
// the same order, as the recursive formulation.
void dfsComponents(int node, int prev) {
    struct Frame {
        int v;        // current vertex
        int from;     // vertex we came from (0 = none)
        size_t next;  // index of the next neighbour of v to examine
    };
    std::vector<Frame> frames;
    // Handles arrival at v (records a cross-component edge, if any) and schedules v's children.
    auto enter = [&](int v, int from) {
        if (from && componentIds[from] != componentIds[v]) {
            adjComponents[componentIds[from]].push_back(componentIds[v]);
            adjComponents[componentIds[v]].push_back(componentIds[from]);
        }
        frames.push_back({v, from, 0});
    };
    enter(node, prev);
    while (!frames.empty()) {
        Frame &cur = frames.back();
        if (cur.next == adj[cur.v].size()) {
            frames.pop_back();
            continue;
        }
        int child = adj[cur.v][cur.next++];
        // `cur` may dangle after enter() grows `frames`; it is not used again this iteration.
        if (child != cur.from) enter(child, cur.v);
    }
}

// Disjoint-set union over indices [0, N).
// For a root r, e[r] < 0 and -e[r] is the size of r's set; for any other x, e[x] is the
// next vertex toward the root.
struct DSU {
    vector<int> e;

    DSU(int N) : e(N, -1) {}

    // Representative of x's set; re-points every vertex on the walked path at the root.
    int get(int x) {
        if (e[x] < 0) return x;
        e[x] = get(e[x]);
        return e[x];
    }

    bool same_set(int a, int b) { return get(a) == get(b); }

    // Number of elements in x's set.
    int size(int x) {
        const int root = get(x);
        return -e[root];
    }

    // Merges the sets of x and y, hanging the smaller set under the larger one's root.
    // Returns false if they were already in the same set.
    bool unite(int x, int y) {
        int rx = get(x);
        int ry = get(y);
        if (rx == ry) return false;
        if (e[ry] < e[rx]) swap(rx, ry);  // ensure rx roots the larger (or equal) set
        e[rx] += e[ry];
        e[ry] = rx;
        return true;
    }
};

// Reads a tree of n cities and the state of each city. Every state's cities, together with
// the tree paths joining them, are contracted into a single DSU component. The program then
// counts the leaves L of the contracted component tree and prints ceil(L / 2); this is 0 when
// everything collapses into one component.
int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    int n, k;
    cin >> n >> k;
    adj = vector<vector<int>>(n + 1);
    parent = vector<vector<int>>(n + 1, vector<int>(20, 0));
    depth = vector<int>(n + 1, 0);
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    dfs(1, 0, 0);
    // Binary-lifting table: parent[i][lvl] is the 2^lvl-th ancestor of i, or 0 past the root.
    for (int lvl = 1; lvl < 20; lvl++) {
        for (int i = 1; i <= n; i++) {
            if (parent[i][lvl - 1]) parent[i][lvl] = parent[parent[i][lvl - 1]][lvl - 1];
        }
    }

    // st[s] = cities belonging to state s.
    vector<vector<int>> st(k + 1);
    for (int i = 1; i <= n; i++) {
        int state;
        cin >> state;
        st[state].push_back(i);
    }

    DSU dsu(n + 1);
    // top[r] = shallowest vertex of the component whose DSU root is r. Components only ever
    // grow by merging a vertex with its tree parent, so each is a connected subtree and its
    // shallowest vertex is an ancestor of every member.
    vector<int> top(n + 1);
    iota(top.begin(), top.end(), 0);

    for (int s = 1; s <= k; s++) {
        if (st[s].empty()) continue;  // a state with no city constrains nothing
        int anc = st[s][0];
        for (size_t j = 1; j < st[s].size(); j++) anc = lca(anc, st[s][j]);
        const int ancDepth = depth[anc];
        // Merge every tree edge on the path from each city up to anc. Jumping straight to the
        // top of the city's current component skips edges that are already merged, so each
        // edge is merged at most once overall. The previous vertex-by-vertex climb re-walked
        // merged paths for every city (O(n) per city) and timed out on deep trees.
        for (int city : st[s]) {
            int x = top[dsu.get(city)];
            while (depth[x] > ancDepth) {
                // x is strictly below anc, so it has a parent in a different component.
                const int up = parent[x][0];
                const int above = top[dsu.get(up)];
                dsu.unite(x, up);
                top[dsu.get(x)] = above;  // merged component's shallowest vertex
                x = above;
            }
        }
    }

    componentIds = vector<int>(n + 1);
    adjComponents = vector<vector<int>>(n + 1);
    for (int i = 1; i <= n; i++) {
        componentIds[i] = dsu.get(i);
    }
    dfsComponents(1, 0);

    // Count components of degree 1 (leaves) in the contracted tree.
    int leaves = 0;
    for (int i = 1; i <= n; i++) {
        if (adjComponents[i].size() == 1) leaves++;
    }
    // A tree with L leaves needs ceil(L / 2) extra connections to leave no bridge edge.
    cout << (leaves + 1) / 2 << '\n';
    return 0;
}
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...