#include<bits/stdc++.h>
using namespace std;
vector<int> t;            // t[v]: colour of vertex v, as read from input
vector<int> depth;        // depth[v]: distance of v from the DFS root (vertex 1 in main)
vector<vector<int>> adj;  // adj[v]: neighbours of v in the input tree
vector<vector<int>> col;  // col[c]: vertices whose colour is c
vector<vector<int>> par;  // par[v][j]: 2^j-th ancestor of v, or 0 if it does not exist
void dfs(int node, int prev, int d) {
depth[node]=d;
par[node][0]=prev;
col[t[node]].push_back(node);
for(auto child : adj[node]) if(child != prev) dfs(child, node, d+1);
}
// Returns the k-th ancestor of x with the binary-lifting table par[][].
// The result is 0 once the walk goes past the root.
// Each set bit of k selects one lifting level, so k must only have bits at
// levels present in par (20 as allocated in main).
int ancestor(int x, int k) {
for (int level = 0; k != 0 && x != 0; ++level) {
if (k & 1) x = par[x][level];
k >>= 1;
}
return x;
}
// Lowest common ancestor of a and b, using depth[] and the lifting table par[][].
int lca(int a, int b) {
if (depth[b] > depth[a]) swap(a, b);      // make a the deeper vertex
a = ancestor(a, depth[a] - depth[b]);     // bring a up to b's depth
if (a == b) return a;                     // b was an ancestor of a
// Raise both vertices while their ancestors differ. They stop as the two
// children of the LCA on the respective paths.
for (int level = 19; level >= 0; --level) {
if (par[a][level] == par[b][level]) continue;
a = par[a][level];
b = par[b][level];
}
return par[a][0];
}
struct DSU{
vector<int> e;
DSU(int N) {e=vector<int>(N, -1);}
int get(int x) {return e[x] < 0 ? x : e[x]=get(e[x]);}
int size(int x) {return -e[get(x)];}
bool unite(int a, int b) {
a=get(a); b=get(b);
if(a == b) return false;
if(e[a] > e[b]) swap(a, b);
e[a]+=e[b]; e[b]=a;
return true;
}
};
void up(int node, int tar, DSU&dsu) {
while(node != tar) {
if(!dsu.unite(node, par[node][0])) return;
node=par[node][0];
}
}
vector<int> componentIds;           // componentIds[v]: DSU representative of vertex v (filled in main)
vector<vector<int>> adjComponents;  // adjComponents[r]: components adjacent to component r across tree edges
// Walks the tree in adj[] from `node` (entered from `prev`; 0 means no parent).
// For every tree edge whose endpoints lie in different components, it records
// the adjacency in both directions in adjComponents, keyed by componentIds.
// An explicit stack replaces recursion, so a deep tree cannot overflow the
// call stack. Children are pushed in reverse, so edges are recorded in the
// same order as the recursive version.
void dfsComponents(int node, int prev) {
vector<pair<int, int>> todo;  // (vertex, vertex it was entered from)
todo.emplace_back(node, prev);
while (!todo.empty()) {
const int v = todo.back().first;
const int p = todo.back().second;
todo.pop_back();
if (p && componentIds[p] != componentIds[v]) {
adjComponents[componentIds[p]].push_back(componentIds[v]);
adjComponents[componentIds[v]].push_back(componentIds[p]);
}
for (auto it = adj[v].rbegin(); it != adj[v].rend(); ++it) {
if (*it != p) todo.emplace_back(*it, v);
}
}
}
// Reads a tree of n vertices and each vertex's colour (assumed to be in 1..k).
// For each colour, it uses up() to merge that colour's vertices and the tree
// paths to their common LCA. It then contracts the merged components and
// prints ceil(leaves / 2), where a leaf is a component with exactly one
// neighbouring component.
signed main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);

int n, k;
cin >> n >> k;
adj = vector<vector<int>>(n + 1);
for (int i = 1; i < n; i++) {
int a, b;
cin >> a >> b;
adj[a].push_back(b);
adj[b].push_back(a);
}
t = depth = vector<int>(n + 1);
for (int i = 1; i <= n; i++) cin >> t[i];
col = vector<vector<int>>(k + 1);
par = vector<vector<int>>(n + 1, vector<int>(20, 0));
dfs(1, 0, 0);
// Binary lifting: the 2^j-th ancestor is the 2^(j-1)-th ancestor of the 2^(j-1)-th ancestor (0 = none).
for (int j = 1; j < 20; j++) {
for (int i = 1; i <= n; i++) {
if (par[i][j - 1]) par[i][j] = par[par[i][j - 1]][j - 1];
}
}

DSU dsu(n + 1);
for (int c = 1; c <= k; c++) {
if (col[c].empty()) continue;  // colour not used by any vertex: col[c][0] would be out of bounds
int top = col[c][0];
for (size_t j = 1; j < col[c].size(); j++) top = lca(top, col[c][j]);
for (int v : col[c]) up(v, top, dsu);
}

componentIds = vector<int>(n + 1);
adjComponents = vector<vector<int>>(n + 1);
for (int i = 1; i <= n; i++) componentIds[i] = dsu.get(i);
dfsComponents(1, 0);

// Only DSU roots can have entries in adjComponents.
int leaves = 0;
for (int i = 1; i <= n; i++) {
if (adjComponents[i].size() == 1) leaves++;
}
cout << (leaves + 1) / 2 << '\n';
return 0;
}