Submission #852832

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
| 852832 | | anton | Mergers (JOI19_mergers) | C++17 | 100 / 100 | 1078 ms | 161220 KiB |
#include<bits/stdc++.h>

using namespace std;
// Widen every `int` in this file to 64 bits (including template arguments such
// as vector<int> and pair<int, int>); main is therefore declared `signed`.
#define int long long 
#define pii pair<int, int>

// n: number of tree nodes (main reads n-1 edges between 1-based ids).
// k: read from input in main but not used anywhere else in this file.
int n, k;
// Capacities of the statically sized per-node and per-group arrays below.
const int MAX_N = 500'000;
const int MAX_K = 500'000;

// Neutral starting value for the minimum computed in dfs2; larger than any depth.
const int INF = 1e18;

// Incremented by dfs3; main prints (res+1)/2.
int res = 0;

// adj[u]: neighbours of node u (0-based) in the undirected tree.
vector<int> adj[MAX_N];
// NOTE(review): mem is never referenced in this file; candidate for removal.
vector<int> mem[MAX_N];
// group[u]: 0-based group id of node u (input value minus one).
int group[MAX_N];
// anc[u]: parent of u as recorded by dfs; main calls dfs(0, 0, 0), so the
// root 0 is recorded as its own parent.
int anc[MAX_N];

// group_lca[g]: LCA of the nodes of group g processed so far; -1 (set in main)
// means no node of g has been folded in yet.
int group_lca[MAX_K];

// depth[u]: distance of u from the root as recorded by dfs (root has depth 0).
int depth[MAX_N];



void dfs(int u, int a, int d){
    //cout<<"dfs "<<u<<endl;
    anc[u] = a;
    depth[u] = d;
    for(auto v: adj[u]){
        if(v!=a){
            dfs(v, u, d+1);
        }
    }
}

int bl[19][MAX_N];

void build_bl(){
    for(int i = 0; i<n; i++){
        bl[0][i] =anc[i];
    }

    for(int i = 1; i<19; i++){
        for(int j = 0; j<n; j++){
            bl[i][j] = bl[i-1][bl[i-1][j]];
        }
    }
}

// Lowest common ancestor of a and b, using the bl and depth tables.
// -1 is treated as "no node yet": lca(-1, x) == lca(x, -1) == x, which lets
// main fold the members of a group into group_lca starting from -1.
int lca(int a, int b){
    if(a == -1) return b;
    if(b == -1) return a;

    // Make b the deeper (or equally deep) of the two.
    if(depth[b] < depth[a]) std::swap(a, b);

    // Raise b to a's depth.
    for(int j = 18; j >= 0; --j){
        const int up = bl[j][b];
        if(depth[up] >= depth[a]) b = up;
    }
    if(b == a) return a;

    // Raise both while their ancestors differ; they end just below the LCA.
    for(int j = 18; j >= 0; --j){
        const int pa = bl[j][a];
        const int pb = bl[j][b];
        if(pa != pb){
            a = pa;
            b = pb;
        }
    }
    return bl[0][a];
}

bool is_free[MAX_N];

int nbf = 0;
// For every node x in the subtree of u (children are neighbours other than
// anc[x]), computes the minimum of depth[group_lca[group[y]]] over all y in
// x's subtree. When that minimum is >= depth[x], no group with a member under x
// has its LCA above x. In that case x is marked in is_free and counted in nbf.
// Returns the minimum for u's whole subtree.
// Iterative (explicit stack plus reverse preorder sweep) so that deep,
// path-like trees cannot overflow the call stack.
int dfs2(int u){
    // Preorder of u's subtree: every parent appears before its children.
    // par_pos[i] is the index in `order` of order[i]'s parent (-1 for u).
    std::vector<int> order, par_pos;
    std::vector<std::pair<int, int>> stk(1, std::pair<int, int>(u, -1));
    while(!stk.empty()){
        const std::pair<int, int> cur = stk.back();
        stk.pop_back();
        const int x = cur.first;
        const int pos = order.size();
        order.push_back(x);
        par_pos.push_back(cur.second);
        for(auto v: adj[x]){
            if(v != anc[x]){
                stk.push_back(std::pair<int, int>(v, pos));
            }
        }
    }

    // best[i]: running minimum for order[i]. Walking the preorder backwards
    // folds every child into its parent before the parent is finalised.
    std::vector<int> best(order.size(), INF);
    for(int i = (int)order.size() - 1; i >= 0; i--){
        const int x = order[i];
        best[i] = std::min(best[i], depth[group_lca[group[x]]]);
        if(best[i] >= depth[x]){
            is_free[x] = true;
            nbf++;
        }
        if(par_pos[i] >= 0){
            best[par_pos[i]] = std::min(best[par_pos[i]], best[i]);
        }
    }
    return best[0];
}

// Post-order pass over the subtree of u, using the is_free marks from dfs2.
// Returns (first, second) for u's subtree:
//   second = number of free nodes in it (including u if free);
//   first  = number of free nodes with no free proper ancestor inside the
//            subtree (so it is 1 when u itself is free).
// A free node x increments res when no free node lies strictly below it
// (first == 0 before x is counted), or when the free nodes strictly below it
// number nbf-1. Presumably that means all other free nodes, since main removes
// the root from nbf before calling this.
// Iterative (explicit stack plus reverse preorder sweep) so that deep,
// path-like trees cannot overflow the call stack.
pii dfs3(int u){
    // Preorder of u's subtree; par_pos[i] is the index in `order` of
    // order[i]'s parent (-1 for u).
    std::vector<int> order, par_pos;
    std::vector<pii> stk(1, pii(u, -1));
    while(!stk.empty()){
        const pii cur = stk.back();
        stk.pop_back();
        const int x = cur.first;
        const int pos = order.size();
        order.push_back(x);
        par_pos.push_back(cur.second);
        for(auto v: adj[x]){
            if(v != anc[x]){
                stk.push_back(pii(v, pos));
            }
        }
    }

    // acc[i]: sums accumulated from the children of order[i]. Children are
    // finalised before their parent because the preorder is walked backwards.
    std::vector<pii> acc(order.size(), pii(0, 0));
    for(int i = (int)order.size() - 1; i >= 0; i--){
        const int x = order[i];
        pii s = acc[i];

        if((s.second == nbf-1 || s.first == 0) && is_free[x]){
            res++;
        }

        if(is_free[x]){
            s.first = 1;
            s.second++;
        }
        acc[i] = s;

        if(par_pos[i] >= 0){
            acc[par_pos[i]].first += s.first;
            acc[par_pos[i]].second += s.second;
        }
    }
    return acc[0];
}

// Reads n, k, the n-1 tree edges (1-based ids) and each node's 1-based group,
// then:
//   1. roots the tree at node 0 (dfs),
//   2. builds the lifting table,
//   3. computes each group's LCA,
//   4. marks free nodes (dfs2),
//   5. counts via dfs3 and prints (res + 1) / 2.
signed main(){
    // Input is up to roughly 3 * MAX_N integers. Untie iostreams from C stdio
    // and from each other so the reads are not synchronised one by one.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    cin>>n>>k;

    // -1 = "no member of this group folded in yet" (lca treats it as identity).
    fill(group_lca, group_lca + MAX_K, -1);

    for(int i = 0; i<n-1; i++){
        int a, b;
        cin>>a>>b;
        adj[a-1].push_back(b-1);
        adj[b-1].push_back(a-1);
    }

    // Root at node 0, which is recorded as its own parent.
    dfs(0, 0, 0);

    build_bl();

    for(int i = 0; i<n; i++){
        cin>>group[i];
        group[i]--;
    }

    for(int i = 0; i<n; i++){
        group_lca[group[i]] = lca(group_lca[group[i]], i);
    }

    dfs2(0);
    // dfs2 always marks the root: its subtree minimum is a depth >= 0 == depth[0].
    // The root has no parent edge, so un-mark it and drop it from the count.
    is_free[0] = false;
    nbf--;
    dfs3(0);
    cout<<(res+1)/2<<'\n';
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...