제출 #563189

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
563189 | | zaneyu | Mergers (JOI19_mergers) | C++14 | 100 / 100 | 712 ms | 82368 KiB
/*input
2 2
1 2
1
2
*/
#include<bits/stdc++.h>
using namespace std;
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
typedef tree<long long,null_type,less_equal<long long>,rb_tree_tag,tree_order_statistics_node_update> indexed_set;
#pragma GCC optimize("Ofast")
//#pragma GCC target("avx2")
//order_of_key #of elements less than x
// find_by_order kth element
// Short type aliases used throughout the file.
using ll=long long;
using ld=long double;
using pii=pair<ll,ll>;
// Shorthand for the pair members `first` / `second`.
// NOTE: these are object-like macros, so ANY identifier spelled `f` or `s`
// later in the file is rewritten as well — avoid those names for variables.
#define f first
#define s second
// Shorthand for the member function name push_back.
#define pb push_back
// Counting loops: REP runs i over 0..n-1 with an int index, REP1 runs i over
// 1..n with a long long index. `n` is pasted unparenthesized and is
// re-evaluated on every iteration.
#define REP(i,n) for(int i=0;i<n;i++)
#define REP1(i,n) for(ll i=1;i<=n;i++)
// memset over a whole object; sizeof(n) is only the full size when `n` is a
// real array (not a pointer), and `x` is applied per byte.
#define FILL(n,x) memset(n,x,sizeof(n))
// Expands to the begin/end iterator pair of a container.
#define ALL(_a) _a.begin(),_a.end()
// Container size converted to int.
#define sz(x) (int)x.size()
// Sorts the container, then shrinks it to its distinct elements.
#define SORT_UNIQUE(c) (sort(c.begin(),c.end()),c.resize(distance(c.begin(),unique(c.begin(),c.end()))))
// Capacity of every fixed-size global array in this file (par, dep, v, comp,
// cnt and UFDS::par).
const ll maxn=5e5+5;
// floor(log2(maxn)) + 2; not referenced by the code visible in this file.
const ll maxlg=__lg(maxn)+2;
// Large sentinel values; not referenced by the code visible in this file.
const ll INF64=4e18;
const int INF=0x3f3f3f3f;
// Modulus used by the two-argument mult().
const ll MOD=998244353;
// pi and a floating-point tolerance; not referenced by the code visible in
// this file.
const ld PI=acos(-1.0L);
const ld eps=1e-6;
// lowb(x): value of the lowest set bit of x (x & -x).
// The argument and the whole expansion are parenthesized so the macro can be
// used inside larger expressions: with the bare form `x&(-x)`, a use such as
// `lowb(v)==4` parsed as `v&((-v)==4)` and `lowb(a|b)` mixed precedences.
#define lowb(x) ((x)&(-(x)))
// MNTO(x,y) / MXTO(x,y): assign to x the minimum / maximum of x and y, with y
// first converted to the type of x. `y` is parenthesized so the conversion
// applies to the whole argument rather than only to its first operand.
// Note that x is evaluated more than once, so it must be side-effect free.
#define MNTO(x,y) ((x)=std::min((x),(__typeof__(x))(y)))
#define MXTO(x,y) ((x)=std::max((x),(__typeof__(x))(y)))
// Streams a pair as "<first> <second>" (single space, no trailing newline).
// The pair is taken by const reference so that printing never copies it;
// callers that passed a pair by value keep working unchanged.
// Returns the same stream to allow chaining.
template<typename T1,typename T2>
std::ostream& operator<<(std::ostream& out,const std::pair<T1,T2>& P){
    out<<P.first<<' '<<P.second;
    return out;
}
// Streams a vector on one line: every element is printed as `element + 1`
// (i.e. shifted by one, which turns 0-based values into 1-based output),
// separated by single spaces and terminated by a newline.
// An empty vector prints nothing at all (not even the newline).
// The vector is taken by const reference so that printing does not allocate
// and copy the whole container on every call.
template<typename T>
std::ostream& operator<<(std::ostream& out,const std::vector<T>& V){
    const int count=static_cast<int>(V.size());
    for(int idx=0;idx<count;++idx){
        out<<V[idx]+1;
        // Space between elements, newline after the last one.
        if(idx+1==count) out<<"\n";
        else out<<" ";
    }
    return out;
}
// Returns a*b modulo MOD, using C++ truncated-remainder semantics (the sign
// of the result follows the sign of the mathematical product).
// Both operands are reduced modulo MOD before multiplying: each factor then
// has magnitude below MOD (< 2^30), so the product always fits in a signed
// 64-bit integer. Multiplying the raw operands overflowed (undefined
// behaviour) whenever |a*b| exceeded the long long range. For inputs that did
// not overflow the result is unchanged.
ll mult(ll a,ll b){
    return (a%MOD)*(b%MOD)%MOD;
}
// Returns a*b modulo `mod` without ever forming the full product, by
// double-and-add on the bits of b.
//
// The result follows C++ truncated-remainder semantics, i.e. it equals what
// (a*b)%mod would give with unbounded integers: its sign is the sign of the
// product and its magnitude is below |mod|.
//
// Differences from the naive signed loop:
//  * a negative b no longer spins forever (an arithmetic right shift of a
//    negative value never reaches zero);
//  * magnitudes are handled as unsigned values and every addition is done as
//    a compare-and-subtract, so nothing overflows even when |mod| is close to
//    2^63 or a is not reduced beforehand;
//  * b is reduced modulo |mod| first, which bounds the loop by the bit length
//    of mod.
// Precondition: mod != 0 unless b == 0 (b == 0 returns 0 without dividing).
long long mult(long long a,long long b,long long mod){
    using u64=unsigned long long;
    if(b==0) return 0;
    const bool negative=(a<0)!=(b<0);
    // 0 - x in unsigned arithmetic yields |x| even for LLONG_MIN.
    const u64 m=mod<0?0-static_cast<u64>(mod):static_cast<u64>(mod);
    u64 x=(a<0?0-static_cast<u64>(a):static_cast<u64>(a))%m;
    u64 y=(b<0?0-static_cast<u64>(b):static_cast<u64>(b))%m;
    u64 acc=0;
    while(y){
        // Invariant: acc and x are in [0, m). "p >= m - q" detects p+q >= m
        // without computing the possibly overflowing sum.
        if(y&1) acc=(acc>=m-x)?acc-(m-x):acc+x;
        x=(x>=m-x)?x-(m-x):x+x;
        y>>=1;
    }
    return negative?-static_cast<long long>(acc):static_cast<long long>(acc);
}
// Returns a^b modulo `mod` by binary exponentiation.
//
//  * b <= 0 yields the empty product reduced modulo mod (1 % mod), so the
//    answer is 0 rather than 1 when mod is 1. Negative exponents are NOT
//    treated as modular inverses.
//  * Intermediate products are formed in 128-bit arithmetic (GCC/Clang
//    __int128), so they cannot overflow for any 64-bit modulus; plain 64-bit
//    multiplication overflowed once mod exceeded roughly 3e9.
//  * As with the % operator, a negative base can produce a negative result.
// Precondition: mod != 0.
long long mypow(long long a,long long b,long long mod){
    if(b<=0) return 1%mod;
    a%=mod;
    long long res=1%mod;
    while(b){
        if(b&1) res=static_cast<long long>(static_cast<__int128>(res)*a%mod);
        a=static_cast<long long>(static_cast<__int128>(a)*a%mod);
        b>>=1;
    }
    return res;
}
// par[u]: parent of vertex u as assigned by dfs() (the root gets the `p`
//         value passed to the initial call; main() passes -1).
// dep[u]: depth of u below the dfs() root; the root itself keeps the
//         zero-initialized value.
int par[maxn],dep[maxn];
// v[u]:    adjacency list of vertex u (both directions of every edge).
// comp[x]: vertices whose 0-based label read in main() is x.
vector<int> v[maxn],comp[maxn];
// Disjoint-set forest over the indices 0..n-1 with path compression.
//
// merge(a,b) always hangs the root of a's set beneath the root of b's set.
// There is deliberately no union by rank/size: the free function merge() in
// this file calls uf.merge(x, par[x]) and then reads dep[] / par[] of the
// resulting representative, so the representative has to be the root on the
// `b` side.
//
// Because the direction is fixed, parent chains can become as long as the
// number of elements. find() is therefore iterative (two passes: locate the
// root, then repoint the path) instead of recursive, so a long chain cannot
// exhaust the call stack. The resulting forest is the same as with recursive
// path compression.
struct UFDS{
    int par[maxn];
    // Makes each of the first n indices its own singleton set.
    void init(int n){
        for(int i=0;i<n;++i) par[i]=i;
    }
    // Unites the sets of a and b; the root of b's set becomes the
    // representative of the union. No-op if they are already together.
    void merge(int a,int b){
        a=find(a),b=find(b);
        if(a==b) return;
        par[a]=b;
    }
    // Returns the representative of u's set and points every node on the
    // walked path directly at it.
    int find(int u){
        int root=u;
        while(par[root]!=root) root=par[root];
        while(par[u]!=root){
            const int up=par[u];
            par[u]=root;
            u=up;
        }
        return root;
    }
}uf;
void dfs(int u,int p){
    par[u]=p;
    for(int x:v[u]){
        if(x==p) continue;
        dep[x]=dep[u]+1;
        dfs(x,u);
    }
}
// Unites, in the union-find `uf`, every vertex on the tree path between a and
// b. Works on set representatives: while the two representatives differ, the
// one with the larger dep[] (ties go to a's side) is merged into the set of
// its tree parent and replaced by the new representative.
void merge(int a,int b){
    int x=uf.find(a);
    int y=uf.find(b);
    while(x!=y){
        // Pick the representative to lift; on equal depth lift x.
        int& lower=(dep[x]>=dep[y])?x:y;
        uf.merge(lower,par[lower]);
        lower=uf.find(lower);
    }
}
// cnt[r]: for a union-find representative r, how many adjacency-list entries
// join r's set to a different set. main() scans both directions of every
// edge, so each such edge adds 2 to the counter of both sets it touches.
int cnt[maxn];
int32_t main(){
    ios::sync_with_stdio(false),cin.tie(0);
    int n,k;
    cin>>n>>k;
    REP(i,n-1){
        int a,b;
        cin>>a>>b;
        --a,--b;
        v[a].pb(b),v[b].pb(a);
    }
    dfs(0,-1);
    REP(i,n){
        int x;
        cin>>x;
        --x;
        comp[x].pb(i);
    }
    uf.init(n);
    REP(i,n){
        for(int x:comp[i]){
            merge(x,comp[i][0]);
        }
    }
    REP(i,n){
        for(int x:v[i]){
            if(uf.find(i)!=uf.find(x)){
                cnt[uf.find(i)]++,cnt[uf.find(x)]++;
            }
        }
    }
    int ans=0;
    REP(i,n) if(uf.find(i)==i and cnt[i]==2) ++ans;
    cout<<(ans+1)/2;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...