#include <bits/stdc++.h>
using namespace std;
using ll = long long; using pii = pair<ll,ll>;
// Combine two {vertex, depth} entries for the range-minimum tree: keep the one
// whose depth (`second`) is smaller; when the depths are equal the right-hand
// argument wins.
pii fz(pii p1, pii p2) {
    return (p2.second <= p1.second) ? p2 : p1;
}
// Segment-tree geometry: E levels above the leaves, Nm = 2^E leaves.
constexpr ll E = 20;
constexpr ll Nm = 1LL << E;
// Sentinel used by qry() as the result of an empty range.
constexpr ll INF = 1'000'000'000'000'000'000LL;
// st[Nm + i] holds Euler-tour entry i as {vertex, depth};
// st[p] combines its children st[2p] and st[2p+1] with fz().
pii st[2 * Nm];
// focc[v] = index of v's first appearance in the Euler tour, -1 until visited.
vector<ll> focc(Nm, -1);
// Number of trailing zero bits of x, i.e. the exponent of the largest power of
// two dividing x. Returns 64 for x == 0 (every power of two divides 0), the same
// convention as std::countr_zero on a 64-bit value.
//
// Why the explicit zero case: the __builtin_ctz family is undefined for 0, and
// qry() calls v2(a) with a == 0 whenever a range starts at Euler index 0 (the
// root's first occurrence). Returning a value larger than any v2(b + 1) sends
// qry() into the branch that peels blocks off the right end, which is always
// valid for a range starting at 0.
// The 64-bit builtin is used so the whole ll value is inspected rather than a
// truncated unsigned int.
ll v2(ll x) {
    if (x == 0) {
        return 64;
    }
    return __builtin_ctzll(static_cast<unsigned long long>(x));
}
// Range-minimum query over Euler-tour positions [a, b] (inclusive).
// Returns the {vertex, depth} entry of smallest depth, combining pieces
// left-to-right with fz(), or {INF, INF} when the range is empty (a > b).
// The range is split into aligned power-of-two blocks. At each step the
// endpoint with fewer trailing zero bits (a, or the exclusive end b+1) has its
// aligned block peeled off and read directly from st. On a tie, the right end
// is peeled.
pii qry(ll a, ll b) {
    if (b < a) {
        return {INF, INF};
    }
    const ll va = v2(a);
    const ll vb = v2(b + 1);
    if (va >= vb) {
        // Block [b+1-2^vb, b]: the level-vb node whose span ends at b.
        const pii tail = st[(b >> vb) + (1 << (E - vb))];
        return fz(qry(a, b - (1 << vb)), tail);
    }
    // Block [a, a+2^va): the level-va node whose span starts at a.
    const pii head = st[(a >> va) + (1 << (E - va))];
    return fz(head, qry(a + (1 << va), b));
}
// Lowest common ancestor of vertices a and b. The shallowest entry of the
// Euler tour between their first occurrences is the answer.
// Both vertices must already have a first occurrence (focc != -1).
ll lca(ll a, ll b) {
    a = focc[a];
    b = focc[b];
    assert(a!=-1);
    assert(b!=-1);
    const pii best = (a <= b) ? qry(a, b) : qry(b, a);
    return best.first;
}
// Disjoint-set forest over vertex ids: dsuf[v] is v's parent link (a root links
// to itself), dsz[r] is the size of the set whose root is r.
ll dsuf[Nm]{};
ll dsz[Nm]{};
// Return the root of a's set. Every vertex on the path from a is re-linked
// directly to that root (path compression).
ll getf(ll a) {
    ll root = a;
    while (dsuf[root] != root) {
        root = dsuf[root];
    }
    for (ll cur = a; cur != root;) {
        const ll next = dsuf[cur];
        dsuf[cur] = root;
        cur = next;
    }
    return root;
}
// Union the sets containing a and b, by size. The root of the smaller set is
// linked under the other root; when the sizes are equal, b's root stays the root.
void mrg(ll a, ll b) {
    const ll ra = getf(a);
    const ll rb = getf(b);
    if (ra == rb) {
        return;
    }
    const bool aBigger = dsz[ra] > dsz[rb];
    const ll child = aBigger ? rb : ra;
    const ll parent = aBigger ? ra : rb;
    dsuf[child] = parent;
    dsz[parent] += dsz[child];
}
// Reads a tree on N vertices (1-based edge endpoints) and a colour S[i] in
// [1, K] for every vertex.
// For each colour, every edge on the minimal subtree spanning that colour's
// vertices is contracted (vertices merged in the DSU).
// Prints ceil(L / 2), where L is the number of contracted components with
// exactly one remaining (uncontracted) incident edge, i.e. the leaves of the
// contracted tree.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    ll N, K;
    cin >> N >> K;
    // The Euler tour has 2N-1 entries and must fit in the Nm segment-tree
    // leaves; this also bounds the vertex ids used to index focc/dsuf/dsz.
    assert(N >= 1 && 2 * N - 1 <= Nm);

    // Heap-allocated containers. Variable-length arrays such as `vector<ll> adj[N]`
    // are not standard C++, and for large N they need tens of MB of stack.
    vector<ll> S(N);
    vector<vector<ll>> adj(N);   // undirected adjacency lists
    vector<ll> radj(N, -1);      // radj[v] = parent of v in the rooted tree (-1 for the root)
    vector<vector<ll>> fadj(N);  // fadj[v] = children of v
    vector<ll> ht(N, 0);         // ht[v] = depth of v (root has depth 0)
    for (ll i = 0; i < N - 1; i++) {
        ll a, b;
        cin >> a >> b;
        a--;
        b--;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    vector<vector<ll>> kloc(K);  // kloc[k] = vertices of colour k (0-based)
    vector<ll> ktop(K, -1);      // ktop[k] = LCA of all vertices in kloc[k]
    for (ll i = 0; i < N; i++) {
        cin >> S[i];
        S[i]--;
        kloc[S[i]].push_back(i);
        dsuf[i] = i;
        dsz[i] = 1;
    }

    // Iterative DFS from ROOT producing:
    //   dfs - vertices in preorder;
    //   fv  - the Euler tour as {vertex, depth}, with focc[v] its first index.
    // A {x, 1} stack entry re-emits x into the tour after one child subtree ends.
    const ll ROOT = N - 1;
    radj[ROOT] = -1;
    ht[ROOT] = 0;
    stack<pii> q0;  // {vertex, 0 = expand, 1 = re-emit into tour}
    q0.push({ROOT, 0});
    vector<pii> fv;
    fv.reserve(2 * N - 1);
    vector<ll> dfs;
    dfs.reserve(N);
    while (!q0.empty()) {
        const pii p0 = q0.top();
        q0.pop();
        const ll x = p0.first;
        if (p0.second == 0) {
            dfs.push_back(x);
            focc[x] = (ll)fv.size();
            fv.push_back({x, ht[x]});
            for (ll y : adj[x]) {
                if (y == radj[x]) {
                    continue;
                }
                radj[y] = x;
                fadj[x].push_back(y);
                ht[y] = ht[x] + 1;
                q0.push({x, 1});
                q0.push({y, 0});
            }
        } else {
            fv.push_back({x, ht[x]});
        }
    }

    // Build the min-by-depth segment tree over the Euler tour (used by lca()).
    for (ll i = 0; i < (ll)fv.size(); i++) {
        st[i + Nm] = fv[i];
    }
    for (ll p = Nm - 1; p >= 1; p--) {
        st[p] = fz(st[2 * p], st[2 * p + 1]);
    }

    // Difference marks. Each colour-k vertex adds +1 and the colour's LCA
    // subtracts |kloc[k]|. After subtree sums, colour k contributes to
    // subtree(v) only when that subtree does not contain ktop[k]. So
    // diff[v] > 0 iff edge v-parent(v) lies on some colour's spanning subtree.
    vector<ll> diff(N, 0);
    for (ll k = 0; k < K; k++) {
        if (kloc[k].empty()) {
            continue;
        }
        ktop[k] = kloc[k][0];
        for (ll x : kloc[k]) {
            ktop[k] = lca(x, ktop[k]);
        }
        for (ll x : kloc[k]) {
            diff[x]++;
        }
        diff[ktop[k]] -= (ll)kloc[k].size();
    }

    // Reverse preorder visits every child before its parent, so diff[x] becomes
    // the full subtree sum. Covered edges are contracted.
    for (ll T = N - 1; T >= 1; T--) {
        const ll x = dfs[T];
        for (ll y : fadj[x]) {
            diff[x] += diff[y];
        }
        if (diff[x] != 0) {
            mrg(x, radj[x]);
        }
    }

    // Degree of each contracted component = number of uncontracted edges
    // incident to it.
    vector<ll> cnt(N, 0);
    for (ll T = N - 1; T >= 1; T--) {
        const ll x = dfs[T];
        if (diff[x] == 0) {
            cnt[getf(x)]++;
            cnt[getf(radj[x])]++;
        }
    }
    const ll leaf = count(cnt.begin(), cnt.end(), 1LL);

    // ceil(leaf / 2); integer division already yields 0 when there are no leaves.
    cout << (leaf + 1) / 2 << "\n";
}