#include<bits/stdc++.h>
using namespace std;
#define int long long
// Array capacity: vertices are 1..n and colour ids index cnt/vis/col
// (NOTE(review): assumes n and every colour id are < maxn).
const int maxn = 2e5 + 5;
// Per-colour arrays:
//   cnt[x] - number of vertices of colour x in the whole tree (filled by main)
//   vis[x] - colour x already pulled into the current closure BFS
int cnt[maxn], vis[maxn];
// Per-vertex arrays:
//   del[v] - v has already been used as a centroid (treated as removed)
//   par[v] - parent of v when the current component is rooted at its centroid
//   sz[v]  - subtree size computed by the most recent dfs()
//   a[v]   - colour of v
int del[maxn], par[maxn], sz[maxn], a[maxn];
// Minimum merge count found so far; starts at a large sentinel.
int ans = 1e9;
// col[x]: vertices of colour x inside the component being processed (scratch).
vector<int> col[maxn];
// g[v]: adjacency list of the tree.
vector<int> g[maxn];
// node: vertices of the component being processed (scratch).
vector<int> node;
void dfs(int u, int p){
sz[u] = 1;
for(int v: g[u]){
if(v == p || del[v]) continue;
dfs(v, u);
sz[u] += sz[v];
}
}
// Starting at u (sz[] must be rooted at u), keeps stepping into the first live
// child whose subtree holds more than tot/2 vertices; the vertex where no such
// child exists is returned as the centroid of the component of size tot.
int centroid(int u, int p, int tot){
    for(bool stepped = true; stepped; ){
        stepped = false;
        for(int v: g[u]){
            bool heavy_child = v != p && !del[v] && sz[v] > tot / 2;
            if(!heavy_child) continue;
            p = u;
            u = v;
            stepped = true;
            break;  // g[u] now refers to a different vertex; restart the scan
        }
    }
    return u;
}
// Appends every vertex of u's live component to `node` in DFS pre-order and
// records par[v] as v's parent when rooted at u (par[u] = p). Uses an explicit
// stack to avoid O(n) recursion depth on path-shaped trees; children are
// pushed in reverse so the visit order matches the recursive version exactly.
void get(int u, int p){
    vector<pair<int, int>> st(1, make_pair(u, p)); // pending (vertex, parent)
    while(!st.empty()){
        int x = st.back().first;
        int px = st.back().second;
        st.pop_back();
        node.push_back(x);
        par[x] = px;
        for(auto it = g[x].rbegin(); it != g[x].rend(); ++it){
            int v = *it;
            if(v != px && !del[v]) st.emplace_back(v, x);
        }
    }
}
void decompose(int u){
dfs(u, -1);
int c = centroid(u, -1, sz[u]);
get(c, -1);
del[c] = true;
for(int x: node) col[a[x]].push_back(x);
queue<int> q;
vis[a[c]] = 1;
q.push(a[c]);
int res = 0, ok = 1;
while(q.size() && ok){
int x = q.front();
q.pop();
if(cnt[x] != col[x].size()) ok = 0;
for(int v: col[x]){
if(par[v] != -1 && !vis[a[par[v]]]){
vis[a[par[v]]] = 1;
q.push(a[par[v]]);
res++;
}
}
}
for(int x: node){
col[a[x]].clear();
vis[a[x]] = 0;
}
node.clear();
if(ok) ans = min(ans, res);
for(int v: g[c]) if(!del[v]) decompose(v);
}
// Input: n and k, then n-1 tree edges, then the colour of each vertex 1..n.
// Output: the minimum merge count found by decompose() (no trailing newline).
signed main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int n, k;
    cin >> n >> k;  // k is read to consume the input; it is not used further
    for(int e = 1; e < n; ++e){
        int x, y;
        cin >> x >> y;
        g[x].push_back(y);
        g[y].push_back(x);
    }
    for(int v = 1; v <= n; ++v){
        cin >> a[v];
        ++cnt[a[v]];
    }
    decompose(1);
    cout << ans;
}