#include<bits/stdc++.h>
using namespace std;
const int N = 5E5 + 1 , LOG = 19;
int up[N][LOG] , in[N] , out[N] , d[N] , s[N] , pa[N] , c[N] , T = 0;
vector<int> g[N] , e[N];
int find(int u){
return u == pa[u] ? u : pa[u] = find(pa[u]);
}
bool same(int u , int v){
return find(u) == find(v);
}
void join(int u , int v){
u = find(u);
v = find(v);
if(u != v){
pa[v] = u;
}
}
void dfs(int u , int P){
up[u][0] = P;
d[u] = d[P] + 1;
in[u] = ++T;
for(int v : g[u]){
if(v != P){
dfs(v , u);
}
}
out[u] = T;
}
bool parent(int u , int v){
return in[u] <= in[v] && out[v] <= out[u];
}
int lca(int u , int v){
if(d[u] < d[v]){
swap(u , v);
}
for(int b = 0;b < LOG;b ++){
if((d[u] - d[v]) >> b & 1){
u = up[u][b];
}
}
if(u == v){
return u;
}
for(int b = LOG - 1;b >= 0;b --){
if(up[u][b] != up[v][b]){
u = up[u][b];
v = up[v][b];
}
}
return up[u][0];
}
void DFS(int u){
for(int v : g[u]){
if(v != up[u][0]){
DFS(v);
s[u] += s[v];
}
}
if(s[u] > 0){
join(u , up[u][0]);
}
}
int main(){
cin.tie(0)->sync_with_stdio(0);
int n , K;
cin >> n >> K;
for(int i = 1;i < n;i ++){
int u , v;
cin >> u >> v;
g[u].emplace_back(v);
g[v].emplace_back(u);
}
dfs(1 , 1);
for(int b = 1;b < LOG;b ++){
for(int i = 1;i <= n;i ++){
up[i][b] = up[up[i][b - 1]][b - 1];
}
}
for(int i = 1;i <= n;i ++){
int a;
cin >> a;
e[a].emplace_back(i);
}
for(int a = 1;a <= K;a ++){
sort(e[a].begin() , e[a].end() , [&](int x , int y){
return in[x] < in[y];
});
vector<int> vr = e[a];
for(int i = 0;i < (int)e[a].size() - 1;i ++){
vr.emplace_back(lca(e[a][i] , e[a][i + 1]));
}
sort(vr.begin() , vr.end() , [&](int x , int y){
return in[x] < in[y];
});
stack<int> S;
for(int i = 0;i < (int)vr.size();i ++){
while(!S.empty() && !parent(S.top() , vr[i])){
S.pop();
}
if(!S.empty()){
s[vr[i]] += 1;
s[S.top()] -= 1;
}
S.push(vr[i]);
}
}
iota(pa , pa + n + 1 , 0);
DFS(1);
for(int u = 1;u <= n;u ++){
for(int v : g[u]){
if(!same(u , v)){
c[find(u)] += 1;
}
}
}
int ans = 0;
for(int i = 1;i <= n;i ++){
ans += c[i] == 1;
}
cout << (ans + 1) / 2;
}