#include<bits/stdc++.h>
// Every `int` in this file is really `long long` (64-bit); this is why the
// entry point has to be declared `signed main`.
#define int long long
// Short-hand aliases used throughout the file.
#define pb push_back
#define eb emplace_back
#define mp make_pair
#define F first
#define S second
#define vi vector<int>
#define pii pair<int,int>
// Loop helpers: f0r iterates i over [0, n), FOR iterates i over [k, n).
#define f0r(i,n) for(int i = 0; i < n; i++)
#define FOR(i,k,n) for(int i = k; i < n; i++)
// Debug printers; in this file they appear only in commented-out code.
// NOTE: vout expands to several statements, so it is not safe as the body of
// an unbraced if/for.
#define dout(x) cout<<x<<' '<<#x<<endl
#define dout2(x,y) cout<<x<<' '<<#x<<' '<<y<<' '<<#y<<endl
#define vout(v) cout<<#v<<": "; for(auto u : v)cout<<u<<' '; cout<<endl
using namespace std;
// Capacity of the per-vertex arrays; assumes n < mxn (TODO confirm against the
// input limits).
const int mxn = 2e5 + 5;
// n         : number of vertices; k : number of colours (both read in main).
// c[v]      : 0-based colour of vertex v.
// sz[v]     : subtree size computed by get_sz inside the current component.
// vis[v]    : 1 once v has been chosen as a centroid (removed for deeper levels).
// isc[v]    : 1 while v belongs to the component currently being processed.
// cvis[v]   : 1 once v has been reached by the current level's search.
// par[v]    : parent of v when the current component is rooted at its centroid.
// ans       : smallest colour count over all feasible centroids (1e9 = none yet).
// adj       : adjacency lists built from the n-1 edges read in main.
// visc[col] : 1 once colour col has been absorbed at the current level.
// cc        : vertices of the current component, kept so isc/cvis can be reset.
int n, k, c[mxn], sz[mxn], vis[mxn], isc[mxn], cvis[mxn], par[mxn], ans = 1e9; vi adj[mxn], visc, cc;
// colnode[col] : every vertex whose colour is col.
vector<vi>colnode;
// Computes sz[v] (number of vertices in v's subtree) for every vertex of the
// component containing `node`. The component is rooted at `node`, with `from`
// treated as node's parent (-1 for none). Vertices with vis set count as
// already removed and are not entered.
//
// Iterative on purpose: the recursive form needed one stack frame per level
// of the tree, so a long path (about 2e5 vertices) could overflow a typical
// 8 MB default stack.
void get_sz(int node, int from){
    // (vertex, parent) pairs in BFS order: each child appears after its parent.
    std::vector<std::pair<int, int>> order;
    order.emplace_back(node, from);
    for (std::size_t i = 0; i < order.size(); i++) {
        int v = order[i].first, p = order[i].second;  // copy: emplace_back may reallocate
        sz[v] = 1;
        for (int u : adj[v]) if (u != p && !vis[u]) order.emplace_back(u, v);
    }
    // Walk backwards so every child is finished before it is added to its
    // parent. Index 0 is the root, whose "parent" may be -1, so it is skipped.
    for (std::size_t i = order.size(); i-- > 1; ) sz[order[i].second] += sz[order[i].first];
}
// Starting from `node` (entered from `from`), keep stepping into the first
// non-removed neighbour (other than the one we came from) whose sz exceeds
// `tar`. Returns the vertex where no such neighbour exists; with
// tar = sz[start] / 2 this is the centroid of the component.
int get_centroid(int node, int from, int tar){
    bool stepped = true;
    while (stepped) {
        stepped = false;
        for (int nxt : adj[node]) {
            if (nxt == from || vis[nxt] || sz[nxt] <= tar) continue;
            from = node;
            node = nxt;
            stepped = true;
            break;  // leave the range-for before it would observe the new node
        }
    }
    return node;
}
// Roots the component containing `node` at `node`; vertices with vis set
// count as removed. For every vertex v it reaches, it records par[v] (par[node]
// is `from`), sets isc[v] = 1 and appends v to cc, in the same preorder the
// recursive version used.
//
// Iterative on purpose: the recursive form needed one stack frame per level
// of the tree, so a long path (about 2e5 vertices) could overflow a typical
// 8 MB default stack.
void dfs(int node, int from){
    std::vector<std::pair<int, int>> st;  // pending (vertex, parent) pairs
    st.emplace_back(node, from);
    while (!st.empty()) {
        int v = st.back().first, p = st.back().second;
        st.pop_back();
        par[v] = p;
        cc.push_back(v);
        isc[v] = 1;
        // Push neighbours in reverse so adj[v] is explored front-to-back, which
        // reproduces the recursive preorder exactly.
        for (auto it = adj[v].rbegin(); it != adj[v].rend(); ++it)
            if (*it != p && !vis[*it]) st.emplace_back(*it, v);
    }
}
void centroid_decomp(int node){
get_sz(node, -1); int ce = get_centroid(node, -1, sz[node] / 2); //vout(isc);
node = ce; cc.clear(); dfs(node, -1);// dout(node); vout(isc); vout(cvis);
queue<int>q; q.push(node); cvis[node] = 1; int ct = 0; vi cols; bool ok = 1; while(!q.empty()){
int cur = q.front(); q.pop(); if(!isc[cur]){ok=0;break;} if(!visc[c[cur]]){
visc[c[cur]] = 1; ct++; cols.pb(c[cur]); for(auto u : colnode[c[cur]])if(!cvis[u]){
if(isc[u])cvis[u] = 1; else{ok = 0; break;} q.push(u);
}
}
if(!ok)break;
if(cur != ce){
cur = par[cur]; while(!cvis[cur]){
if(isc[cur])cvis[cur] = 1; else{ok = 0; break;}
q.push(cur); cur = par[cur];
}
}
if(!ok)break;
}
if(ok)ans = min(ans, ct); //dout(ok); vout(cols);
for(auto u : cc)isc[u] = 0, cvis[u] = 0; for(auto u : cols)visc[u] = 0;
vis[ce] = 1; for(auto u : adj[ce])if(!vis[u])centroid_decomp(u);
}
signed main(){
ios::sync_with_stdio(false); cin.tie(NULL);
cin>>n>>k; f0r(i,n-1){int a, b; cin>>a>>b; a--; b--; adj[a].pb(b); adj[b].pb(a);} f0r(i,n)cin>>c[i],c[i]--;
visc.resize(k); colnode.resize(k); f0r(i,n)colnode[c[i]].pb(i);
centroid_decomp(0); cout<<ans-1;
}