#include<bits/stdc++.h>
using namespace std;
// Capacity of the per-vertex arrays below (vertices are indexed from 1 in main()).
const int maxn = 5e5 + 5;
// Adjacency lists; main() stores every input edge in both directions.
vector<int> g[maxn];
// tin[u]  : value of the DFS counter when u is first visited.
// out[u]  : value of the DFS counter when u's subtree has been fully visited.
// timedfs : the DFS counter itself, incremented once per visited vertex.
int tin[maxn], out[maxn], timedfs;
// par[u] : parent of u as recorded by dfs().
// p[u]   : union-find link followed by find(); p[u] == u marks a representative.
int par[maxn], p[maxn];
// Returns the representative of the union-find set that contains u.
// Every node on the chain walked from u is re-pointed directly at that
// representative (full path compression), done here with two loops.
int find(int u){
    // Pass 1: follow the links until the self-linked representative is found.
    int root = u;
    while(p[root] != root){
        root = p[root];
    }
    // Pass 2: walk the same chain again and attach each node to the root.
    while(p[u] != root){
        const int next = p[u];
        p[u] = root;
        u = next;
    }
    return root;
}
// Traverses the tree from u, whose parent is recorded as c, and fills for
// every vertex v reached:
//   par[v] - the vertex v was reached from,
//   tin[v] - pre-order number (timedfs is incremented on the first visit),
//   out[v] - value of timedfs once v's whole subtree has been visited.
// An explicit stack replaces recursion so that a path-shaped tree with up to
// maxn vertices cannot exhaust the call stack. Neighbours are still taken in
// adjacency-list order, so the numbering is the same as the recursive form.
void dfs(int u, int c){
    // Each entry is (vertex, index of the next neighbour in g[vertex] to examine).
    std::vector<std::pair<int, std::size_t> > pending;
    par[u] = c;
    tin[u] = ++timedfs;
    pending.emplace_back(u, std::size_t(0));
    while(!pending.empty()){
        // Copy the fields out: emplace_back below may reallocate the vector.
        const int node = pending.back().first;
        const std::size_t idx = pending.back().second;
        if(idx == g[node].size()){
            // All neighbours handled: the subtree of node is complete.
            out[node] = timedfs;
            pending.pop_back();
            continue;
        }
        ++pending.back().second;
        const int next = g[node][idx];
        if(next == par[node]) continue; // do not walk back up the edge we came from
        par[next] = node;
        tin[next] = ++timedfs;
        pending.emplace_back(next, std::size_t(0));
    }
}
// Checks whether x lies inside the subtree of y (x == y counts as inside),
// using the [tin, out] intervals computed by dfs().
bool inside(int x, int y){
    if(tin[x] < tin[y]){
        return false; // x was numbered before y, so it cannot be below y
    }
    return out[x] <= out[y];
}
// Program entry point.
//
// Input (stdin):  n and k; then n - 1 edges "u v" on vertices 1..n; then n
//                 integers, the i-th being the group (1..k) of vertex i.
// Output (stdout): (L + 1) / 2, where L is the number of contracted
//                 components joined to the rest of the tree by exactly one edge.
//
// Steps:
//   1. Root the tree at vertex 1 (dfs fills par / tin / out).
//   2. For every pair of consecutive members of a group, union all vertices
//      on the tree path between them into one union-find set.
//   3. For each set, count the tree edges that leave it, and count the sets
//      with exactly one such edge.
signed main(){
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(0); std::cout.tie(0);

    int n, k; std::cin >> n >> k;
    for(int i = 1; i <= n - 1; i++){
        int u, v; std::cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1, 0);

    std::vector<std::vector<int> > group(k + 1);
    for(int i = 1; i <= n; i++){
        int s; std::cin >> s;
        group[s].push_back(i);
        p[i] = i; // every vertex starts as its own union-find set
    }

    // Moves `from` upwards, linking it to the representative of its parent's
    // set at each step, until `target` lies in the subtree of `from`.
    // Returns the vertex at which the climb stopped.
    auto climb = [](int from, int target){
        while(!inside(target, from)){
            const int top = find(par[from]);
            p[from] = top;
            from = top;
        }
        return from;
    };

    for(int i = 1; i <= k; i++){
        const std::vector<int>& members = group[i];
        for(std::size_t j = 0; j + 1 < members.size(); j++){
            int u = members[j], v = members[j + 1];
            u = climb(u, v); // lift u until it is an ancestor of v
            v = climb(v, u); // then lift v until it meets u's position
        }
    }

    // Degree of each contracted component, stored at its representative.
    // The count must be per component, not per vertex: several vertices of
    // one component may each touch a different outgoing edge.
    std::vector<int> degree(n + 1);
    for(int i = 1; i <= n; i++){
        const int root = find(i);
        for(int j: g[i]){
            if(root != find(j)) degree[root]++;
        }
    }

    // Non-representative vertices keep degree 0, so only components are counted.
    int leaves = 0;
    for(int i = 1; i <= n; i++) leaves += (degree[i] == 1);
    std::cout << (leaves + 1) / 2;
}