Submission #128226

# | Time | Username | Problem | Language | Result | Execution time | Memory
128226 | Osama_Alkhodairy | Mergers (JOI19_mergers) | C++17 | 100 / 100 | 1614 ms | 132924 KiB
#include <bits/stdc++.h>
using namespace std;
// Prints x followed by a flushing newline and returns 0 from the enclosing
// function. Not referenced anywhere in the visible code.
#define finish(x) return cout << x << endl, 0
// Shorthand for long long. Not referenced anywhere in the visible code.
#define ll long long

// Capacity of every per-node array. Nodes are indexed from 1; index 0 is used
// as the "no parent" sentinel (the traversals are started with parent 0).
const int N = 500001;

// n          - number of nodes (n - 1 edges are read in main).
// k          - number of labels; main iterates h[1..k].
// s[i]       - label read for node i.
// sum[i]     - difference counter written by add_path; after the 2-argument
//              dfs it holds the total over the subtree of i.
// dep[i]     - depth assigned by dfsz (dep[parent] + 1; dep[0] stays 0).
// dfstime[i] - preorder index assigned by dfsz.
// deg[i]     - counter incremented by the 3-argument dfs; main counts the
//              nodes where it equals 1.
// p[i][j]    - ancestor of i reached after 2^j parent steps (0 once the walk
//              leaves the tree, because p[0][*] is zero-initialised).
int n, k, s[N], sum[N], dep[N], dfstime[N], deg[N], p[N][20];
// v[i] - adjacency list of node i.
// h[c] - nodes whose label is c.
vector <int> v[N], h[N];
// Next preorder index to hand out in dfsz.
int ct;

// Preorder traversal from `node` (whose parent is `pnode`) that fills, for
// every reached vertex:
//   dfstime[] - preorder index taken from the global counter `ct`,
//   dep[]     - dep[parent] + 1,
//   p[][j]    - the ancestor 2^j parent steps up (0 when above the start).
//
// Implemented with an explicit heap-allocated stack instead of recursion: the
// recursion depth equals the height of the tree, which for a path-shaped tree
// is the number of vertices and can exhaust the call stack. The order in which
// vertices are visited and written is the same as the recursive formulation.
void dfsz(int node, int pnode){
    // Each entry is (vertex, index of the next neighbour to inspect).
    std::vector<std::pair<int, std::size_t>> pending;

    // Performs the "on entry" work for a vertex and schedules its neighbours.
    auto enter = [&](int cur, int par){
        dfstime[cur] = ct++;
        dep[cur] = dep[par] + 1;
        p[cur][0] = par;
        for(int j = 1 ; j < 20 ; j++){
            p[cur][j] = p[p[cur][j - 1]][j - 1];
        }
        pending.emplace_back(cur, 0);
    };

    enter(node, pnode);
    while(!pending.empty()){
        // Copy the fields out: enter() may reallocate `pending`.
        const int cur = pending.back().first;
        const std::size_t idx = pending.back().second;
        if(idx == v[cur].size()){
            pending.pop_back();
            continue;
        }
        pending.back().second = idx + 1;
        const int next = v[cur][idx];
        // p[cur][0] is the parent recorded on entry; never walk back to it.
        if(next == p[cur][0]) continue;
        enter(next, cur);
    }
}
// Returns the vertex reached from `node` after `k` parent steps, using the
// ancestor table p[][]. Only the low 20 bits of `k` are examined, from the
// highest to the lowest.
int lift(int node, int k){
    int bit = 20;
    while(bit-- > 0){
        const bool takeJump = (k & (1 << bit)) != 0;
        if(takeJump){
            node = p[node][bit];
        }
    }
    return node;
}
// Lowest common ancestor of a and b, computed from dep[] and the ancestor
// table p[][]: first bring the deeper vertex up to the depth of the other,
// then raise both in lock-step while their ancestors differ.
int LCA(int a, int b){
    int upper = a;
    int lower = b;
    if(dep[upper] > dep[lower]){
        std::swap(upper, lower);
    }
    lower = lift(lower, dep[lower] - dep[upper]);
    if(upper == lower){
        return upper;
    }
    int level = 20;
    while(level-- > 0){
        const int nextUpper = p[upper][level];
        const int nextLower = p[lower][level];
        if(nextUpper == nextLower) continue; // would overshoot the meeting point
        upper = nextUpper;
        lower = nextLower;
    }
    // Both are now distinct children of the answer.
    return p[upper][0];
}
// Records the path x..y in the difference array sum[]: +1 at each endpoint and
// -2 at their lowest common ancestor. When x == y the three updates cancel.
void add_path(int x, int y){
    const int meet = LCA(x, y);
    ++sum[x];
    ++sum[y];
    sum[meet] -= 2;
}
// Post-order accumulation: after the call, sum[x] for every vertex x reached
// from `node` (without stepping back to `pnode`) holds the total of the
// original sum[] values over x's subtree.
//
// Implemented with an explicit heap-allocated stack instead of recursion: the
// recursion depth equals the height of the tree, which for a path-shaped tree
// is the number of vertices and can exhaust the call stack. Children are
// folded into their parent in the same order as the recursive formulation.
void dfs(int node, int pnode){
    struct Frame {
        int vertex;
        int parent;
        std::size_t next; // index of the next neighbour of `vertex` to inspect
    };
    std::vector<Frame> pending;
    pending.push_back({node, pnode, 0});
    while(!pending.empty()){
        Frame &top = pending.back();
        if(top.next < v[top.vertex].size()){
            const int child = v[top.vertex][top.next++];
            if(child != top.parent){
                // Copy before push_back, which may invalidate `top`.
                const int parentOfChild = top.vertex;
                pending.push_back({child, parentOfChild, 0});
            }
            continue;
        }
        // All neighbours handled: fold this subtree's total into the parent.
        const int finished = top.vertex;
        pending.pop_back();
        if(!pending.empty()){
            sum[pending.back().vertex] += sum[finished];
        }
    }
}
// Pre-order walk that carries along `root`, the representative of the group
// the current vertex belongs to. For every child whose sum[] is 0 the edge to
// it separates two groups: deg[] of the current representative and of the
// child are both incremented, and the child becomes the representative for its
// own subtree. Children with non-zero sum[] stay in the current group.
//
// Implemented with an explicit heap-allocated stack instead of recursion: the
// recursion depth equals the height of the tree, which for a path-shaped tree
// is the number of vertices and can exhaust the call stack. Vertices are
// processed in the same order as the recursive formulation.
void dfs(int node, int pnode, int root){
    struct Frame {
        int vertex;
        int parent;
        int group;        // representative carried down to this vertex
        std::size_t next; // index of the next neighbour of `vertex` to inspect
    };
    std::vector<Frame> pending;
    pending.push_back({node, pnode, root, 0});
    while(!pending.empty()){
        Frame &top = pending.back();
        if(top.next == v[top.vertex].size()){
            pending.pop_back();
            continue;
        }
        const int child = v[top.vertex][top.next++];
        if(child == top.parent) continue;
        int childGroup = top.group;
        if(sum[child] == 0){
            deg[top.group]++;
            deg[child]++;
            childGroup = child;
        }
        // Copy before push_back, which may invalidate `top`.
        const int parentOfChild = top.vertex;
        pending.push_back({child, parentOfChild, childGroup, 0});
    }
}
// Reads a tree on n nodes and a label in 1..k for each node, then prints
// (L + 1) / 2 where L is the number of group representatives whose deg[] is 1
// after the group structure has been computed.
int main(){
    std::ios_base::sync_with_stdio(0);
    std::cin.tie(0);

    // Input: n, k, then n - 1 undirected edges, then one label per node.
    std::cin >> n >> k;
    for(int edge = 1 ; edge < n ; ++edge){
        int a, b;
        std::cin >> a >> b;
        v[a].push_back(b);
        v[b].push_back(a);
    }
    for(int node = 1 ; node <= n ; ++node){
        std::cin >> s[node];
        h[s[node]].push_back(node);
    }

    // Depths, preorder indices and the ancestor table, rooted at node 1.
    dfsz(1, 0);

    // For every label, visit its nodes in preorder and mark the path between
    // each node and the next one, wrapping around from the last to the first.
    const auto byPreorder = [](int lhs, int rhs){
        return dfstime[lhs] < dfstime[rhs];
    };
    for(int label = 1 ; label <= k ; ++label){
        std::vector<int> &members = h[label];
        std::sort(members.begin(), members.end(), byPreorder);
        const int count = members.size();
        for(int j = 0 ; j < count ; ++j){
            const int following = (j + 1 == count) ? 0 : j + 1;
            add_path(members[j], members[following]);
        }
    }

    // Turn the difference marks into subtree totals, then count, per group
    // representative, how many zero-total edges touch its group.
    dfs(1, 0);
    dfs(1, 0, 1);

    int leaves = 0;
    for(int node = 1 ; node <= n ; ++node){
        if(deg[node] == 1){
            ++leaves;
        }
    }
    std::cout << (leaves + 1) / 2 << std::endl;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...