//#pragma GCC optimize("O3")
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define pii pair<int,int>
#define f first
#define s second
#define all(x) x.begin(),x.end()
#define _ ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
int dx[4]={0,0,1,-1};
int dy[4]={1,-1,0,0};
void setIO(string s) {
freopen((s + ".in").c_str(), "r", stdin);
freopen((s + ".out").c_str(), "w", stdout);
}
const int mxn=4e5+5;
const int inf=1e9;
vector<int> adj[mxn];
bool visited[mxn];
int c[mxn];
vector<int> segtree(mxn*4,inf);
int n,k;
void update(int pos,int val,int l=1,int r=n,int v=1){
if(l==r){
segtree[v]=min(segtree[v],val);
return;
}
int mid=(l+r)/2;
if(pos<=mid) update(pos,val,l,mid,v*2);
else update(pos,val,mid+1,r,v*2+1);
segtree[v]=min(segtree[v*2],segtree[v*2+1]);
}
int query(int tl,int tr,int l=1,int r=n,int v=1){
if(r<tl or tr<l){
return inf;
}
if(tl<=l and r<=tr){
return segtree[v];
}
int mid=(l+r)/2;
return min(query(tl,min(mid,tr),l,mid,v*2),query(max(mid+1,tl),tr,mid+1,r,v*2+1));
}
int main() {_
cin>>n>>k;
for(int i=0;i<n-1;i++){
int a,b;
cin>>a>>b;
adj[a].push_back(b);
adj[b].push_back(a);
}
for(int i=1;i<=n;i++){
cin>>c[i];
c[i]--;
}
int st;
for(int i=1;i<=n;i++){
if((int)adj[i].size()==1) st=i;
}
vector<int> ord;
ord.push_back(-1);
queue<int> q;
q.push(st);
while(!q.empty()){
int v=q.front();
q.pop();
visited[v]=true;
ord.push_back(v);
for(auto u:adj[v]){
if(visited[u]) continue;
q.push(u);
}
}
vector<int> mnn(k),mxx(k);
{
vector<int> mn(k,inf),mx(k,-1);
for(int i=1;i<ord.size();i++){
mx[c[ord[i]]]=max(mx[c[ord[i]]],i);
mn[c[ord[i]]]=min(mn[c[ord[i]]],i);
}
for(int i=1;i<ord.size();i++){
int C=c[ord[i]];
int val=query(mn[C],i);
val=min(val,mn[C]);
if(i==mx[C]){
mnn[C]=val;
}
update(i,val);
}
}
reverse(ord.begin()+1,ord.end());
segtree=vector<int>(mxn*4,inf);
{
vector<int> mn(k,inf),mx(k,-1);
for(int i=1;i<ord.size();i++){
int C=c[ord[i]];
mx[C]=max(mx[C],i);
mn[C]=min(mn[C],i);
}
for(int i=1;i<ord.size();i++){
int C=c[ord[i]];
int val=query(mn[C],i);
val=min(val,mn[C]);
if(i==mx[C]){
mxx[C]=val;
}
update(i,val);
}
}
reverse(ord.begin()+1,ord.end());
vector<int> mn(k,inf),mx(k,-1);
for(int i=1;i<ord.size();i++){
mx[c[ord[i]]]=max(mx[c[ord[i]]],i);
mn[c[ord[i]]]=min(mn[c[ord[i]]],i);
}
for(int i=0;i<k;i++){
mxx[i]=n-mxx[i]+1;
}
vector<int> pre(ord.size());
for(int i=1;i<ord.size();i++){
int C=c[ord[i]];
pre[i]=pre[i-1];
if(i==mx[C]){
pre[i]++;
}
}
int ans=inf;
for(int i=0;i<k;i++){
assert(mnn[i]>=1 and mnn[i]<=n and mxx[i]>=1 and mxx[i]<=n);
ans=min(ans,pre[mxx[i]]-pre[mnn[i]-1]-1);
}
cout<<ans<<'\n';
return 0;
}
//maybe its multiset not set
//yeeorz
//laborz