#include<bits/stdc++.h>
using namespace std;
// Global storage, sized for vertex indices up to 500004 (assumes n <= 500004 — confirm limits).
// Tree adjacency lists, filled from the n-1 input edges in main.
vector<int>adj[500005];
// s[i]: the label (group id in 1..k) read for vertex i.
int s[500005];
// Binary-lifting table: pa[i][u] is the 2^i-th ancestor of u (0 once past the root).
// 20 levels cover depths below 2^20.
int pa[20][500005];
// Depth of each vertex: the root has lv == 1 and the sentinel vertex 0 has lv == 0.
int lv[500005];
// pos[c]: vertices carrying label c, in increasing vertex index.
vector<int>pos[500005];
// Disjoint-set parent array used by fp/un (p[i] == i initially for i in 1..n).
int p[500005];
// Disjoint-set "find" with full path compression.
// Returns the representative of a's set and re-points every vertex on the
// walked path directly at it. Implemented iteratively: efs can link merged
// vertices into a chain as long as the tree itself (up to ~5e5), and a
// recursive find of that depth can overflow the call stack.
int fp(int a){
    int root=a;
    while(p[root]!=root)root=p[root];
    // Second pass: compress the path a -> root.
    while(p[a]!=root){
        const int next=p[a];
        p[a]=root;
        a=next;
    }
    return root;
}
// Disjoint-set "union": attaches a's set underneath b's representative.
// No union-by-rank; b's root always becomes the new root.
void un(int a,int b){
    // The original `p[fp(a)]=fp(b)` evaluates the right operand first under
    // C++17 sequencing rules, so b's root is looked up before a's.
    const int rootB=fp(b);
    const int rootA=fp(a);
    p[rootA]=rootB;
}
// Path-difference counters: main adds +1 at both endpoints of each same-label
// path and -2 at the path's LCA. After efs, sum[u] is the number of such paths
// that use the edge between u and its parent.
int sum[500005];
// Adjacency lists of the contracted tree, indexed by DSU representative.
vector<int>tr[500005];
void dfs(int u,int p){
pa[0][u]=p;
lv[u]=lv[p]+1;
for(int i=1;i<20;i++)pa[i][u]=pa[i-1][pa[i-1][u]];
for(auto x:adj[u])if(x!=p)dfs(x,u);
}
// Aggregates the path-difference counters bottom-up and merges vertices.
// Expects sum[] to hold per-vertex deltas (+1 per path endpoint, -2 at each
// path's LCA). Afterwards sum[v] counts the paths that cross the edge v–parent.
// Every vertex with a non-zero count is unioned with its parent.
// Iterative to avoid recursion as deep as the tree (up to ~5e5):
//   1) collect (vertex, parent) pairs in preorder;
//   2) sweep them in reverse preorder, which visits every vertex after all of
//      its descendants, so sum[v] is final when it is used.
// The start vertex's total is not added into sum[parent], matching the
// recursive original.
void efs(int u,int parent){
    vector<pair<int,int>> order;
    vector<pair<int,int>> stk{{u,parent}};
    while(!stk.empty()){
        const auto [v,pv]=stk.back();
        stk.pop_back();
        order.push_back({v,pv});
        for(int x:adj[v])if(x!=pv)stk.push_back({x,v});
    }
    for(size_t i=order.size();i-->0;){
        const auto [v,pv]=order[i];
        if(i>0)sum[pv]+=sum[v];  // order[0] is the start vertex; it has no parent to feed
        if(sum[v])un(v,pv);
    }
}
// Lowest common ancestor of a and b via the binary-lifting table pa[][]
// and the depths lv[] built by dfs.
int lca(int a,int b){
    // Make `a` the deeper (or equally deep) vertex.
    if(lv[a]<lv[b])swap(a,b);
    // Raise `a` by decreasing powers of two while it stays at or below b's depth.
    for(int bit=19;bit>=0;--bit){
        const int up=pa[bit][a];
        if(lv[up]>=lv[b])a=up;
    }
    if(a==b)return a;
    // Raise both while their ancestors differ; they end just under the LCA.
    for(int bit=19;bit>=0;--bit){
        if(pa[bit][a]==pa[bit][b])continue;
        a=pa[bit][a];
        b=pa[bit][b];
    }
    return pa[0][a];
}
// Reads a tree of n vertices (n-1 edges) and a label s[i] in 1..k for each
// vertex, then:
//   1) links each label's vertices in a cycle of consecutive (by index) pairs
//      and records every pair's tree path with +1/+1/-2 difference counters;
//   2) efs merges every vertex whose parent edge lies on some such path;
//   3) builds the contracted tree over DSU representatives and counts its
//      leaves (degree-1 vertices);
//   4) prints ceil(leaves / 2).
// (Presumably the classic "min extra edges to make a tree 2-edge-connected"
// bound — confirm against the problem statement.)
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int n,k;
    cin>>n>>k;
    vector<pair<int,int>>e;
    e.reserve(max(n-1,0));
    for(int i=0;i<n-1;i++){
        int a,b;
        cin>>a>>b;
        adj[a].push_back(b);
        adj[b].push_back(a);
        e.emplace_back(a,b);
    }
    for(int i=1;i<=n;i++){
        cin>>s[i];
        pos[s[i]].push_back(i);
        p[i]=i;
    }
    dfs(1,0);
    for(int i=1;i<=k;i++){
        const vector<int>&group=pos[i];
        const size_t m=group.size();
        // Consecutive pairs, wrapping around; a singleton pairs with itself
        // and contributes nothing (+1 +1 -2 on the same vertex).
        for(size_t j=0;j<m;j++){
            const int a=group[j];
            const int b=group[(j+1)%m];
            sum[a]++;
            sum[b]++;
            sum[lca(a,b)]-=2;
        }
    }
    efs(1,0);
    // Each original edge joining two different components becomes an edge of
    // the contracted tree.
    for(const auto&[a,b]:e){
        const int ra=fp(a);
        const int rb=fp(b);
        if(ra!=rb){
            tr[ra].push_back(rb);
            tr[rb].push_back(ra);
        }
    }
    const int cnt=static_cast<int>(count_if(tr+1,tr+n+1,
        [](const vector<int>&nb){return nb.size()==1;}));
    cout<<(cnt+1)/2<<'\n';
}