#pragma gcc diagnostic "-std=c++1z"
#include <bits/stdc++.h>
#define int long long
#define pb push_back
#define f first
#define s second
#define MOD 1000000007
#define pii pair<int,int>
#define all(x) (x).begin(),(x).end()
#define allr(x) (x).rbegin(),(x).rend()
using namespace std;
// Capacity of every per-node / per-colour array (all indexing is 1-based).
// NB: `int` here is `long long` because of the `#define int long long` above.
const int N=5e5+5;
// n: number of tree nodes; k: number of colour labels (DSU covers 1..k);
// m: not referenced in the visible code; T: test-case counter (main sets 1).
// c[v]: colour label of node v, read in test_case.
// anc[v][j]: 2^j-th ancestor of v, filled for j = 0..19 (the root, node 1,
//   is its own ancestor); columns 20..21 are never written.
// depth[v]: depth of v, with the root (node 1) at depth 1.
// s[v]: difference-mark counter, summed over subtrees by dfs. The
//   `#define s second` macro renames this array to `second` everywhere.
// cnt[r]: counter incremented per DSU representative r in dfs1 and scanned
//   for r = 1..k in test_case.
int n,m,T,k,c[N],anc[N][22],depth[N],s[N],cnt[N];
// g[v]: adjacency list of the tree; e[col]: nodes whose colour is col.
vector<int> g[N],e[N];
// Roots the subtree at x, whose parent is taken to be p (the root is passed
// as its own parent, e.g. dfs0(1,1)), and fills depth[] and anc[][0] for
// every node reachable from x without stepping back through p.
// Uses an explicit BFS list instead of recursion: with up to ~5e5 nodes a
// path-shaped tree would recurse ~5e5 frames deep, which can overflow a
// typical default-sized call stack.
void dfs0(int x, int p) {
    // (node, parent) pairs; a parent is always processed before its
    // children, so depth[parent] is already final when a child is handled.
    std::vector<std::pair<int, int>> order;
    order.emplace_back(x, p);
    for (std::size_t head = 0; head < order.size(); ++head) {
        // Copy out before emplace_back may reallocate `order`.
        int v = order[head].first;
        int par = order[head].second;
        depth[v] = depth[par] + 1;
        anc[v][0] = par;
        for (auto y : g[v]) {
            if (y == par) continue;
            order.emplace_back(y, v);
        }
    }
}
// Returns the node reached by climbing b levels above a, composing the
// binary-lifting jumps for set bits 0..19 of b (higher bits are ignored).
int ancestor(int a, int b) {
    int node = a;
    for (int bit = 0; bit < 20; ++bit) {
        if ((b >> bit) & 1) {
            node = anc[node][bit];
        }
    }
    return node;
}
// Lowest common ancestor of nodes a and b using binary lifting.
// Requires depth[] and the full anc[][] table (levels 0..19) to be built.
int LCA(int a, int b) {
    if (depth[a] < depth[b]) std::swap(a, b);
    // Bring the deeper node up to b's depth.
    a = ancestor(a, depth[a] - depth[b]);
    // If b was an ancestor of the original a, the two now coincide and b is
    // the answer. Without this check the loop below makes no jumps and
    // anc[a][0] (b's parent) would be returned instead.
    if (a == b) return a;
    // Raise both nodes in lockstep while their ancestors still differ; they
    // end as the two children of the LCA.
    for (int j = 19; j >= 0; j--) {
        if (anc[a][j] != anc[b][j]) {
            a = anc[a][j];
            b = anc[b][j];
        }
    }
    return anc[a][0];
}
// Disjoint-set forest over colour labels 1..k (initialised by make_set in
// test_case): p[v] is v's parent link, sz[v] the size of the set, meaningful
// only when v is a root. NB: dfs and dfs1 take a parameter named `p` that
// shadows this array inside their bodies.
int p[N],sz[N];
// Makes v a singleton set: a root of size one.
void make_set(int v) {
    sz[v] = 1;
    p[v] = v;
}
// Returns the root of v's set and re-points every node on the path from v
// directly at that root (path compression).
int find_set(int v) {
    int root = v;
    while (p[root] != root) root = p[root];
    // Second pass: compress the path just walked.
    while (p[v] != root) {
        int up = p[v];
        p[v] = root;
        v = up;
    }
    return root;
}
void merge_set(int a, int b) {
a=find_set(a); b=find_set(b);
if (a==b) return;
if (sz[a]<sz[b]) swap(a,b);
p[b]=a; sz[a]+=sz[b];
}
// Accumulates the difference marks s[] over subtrees (post-order), starting
// at x whose parent is p. For every node v whose final s[v] is non-zero, the
// colours at both ends of edge v-parent(v) are merged in the DSU (including
// the pair (x, p) itself). As before, s[x] is NOT added into s[p]; that is
// left to the caller.
// Iterative (BFS order walked in reverse) to avoid ~5e5-deep recursion on
// path-shaped trees. The resulting DSU partition matches the recursive
// version; only which colour ends up as a set's representative may differ.
void dfs(int x, int p) {
    std::vector<std::pair<int, int>> order;  // (node, parent), parents first
    order.emplace_back(x, p);
    for (std::size_t head = 0; head < order.size(); ++head) {
        // Copy out before emplace_back may reallocate `order`.
        int v = order[head].first;
        int par = order[head].second;
        for (auto y : g[v]) {
            if (y != par) order.emplace_back(y, v);
        }
    }
    // Reverse BFS order reaches every node after all of its descendants, so
    // s[v] is final when it is inspected.
    for (std::size_t i = order.size(); i-- > 0;) {
        int v = order[i].first;
        int par = order[i].second;
        if (s[v]) merge_set(c[v], c[par]);
        if (i != 0) s[par] += s[v];
    }
}
void dfs1(int x, int p) {
if (find_set(x)!=find_set(p)) {
cnt[find_set(x)]++;
cnt[find_set(p)]++;
}
for (auto y:g[x]) {
if (y==p) continue;
dfs1(y,x);
}
}
// Solves one instance: reads a tree on n nodes (n-1 edges) and a colour
// label c[v] in 1..k per node, merges colours whose nodes are linked along
// tree edges marked by the difference sums, then prints a value derived
// from how many colour groups touch exactly one boundary edge.
void test_case() {
    cin >> n >> k;
    for (int edge = 1; edge < n; edge++) {
        int u, v;
        cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    for (int node = 1; node <= n; node++) {
        cin >> c[node];
        e[c[node]].push_back(node);
    }

    // Root the tree at node 1 (passed as its own parent) and fill the
    // binary-lifting table level by level.
    dfs0(1, 1);
    for (int lvl = 1; lvl <= 19; lvl++) {
        for (int node = 1; node <= n; node++) {
            int half = anc[node][lvl - 1];
            anc[node][lvl] = anc[half][lvl - 1];
        }
    }

    // Per colour: fold LCA over its nodes, then put +1 on each node of that
    // colour and -1 per node on the LCA. dfs later sums these over subtrees
    // and merges the endpoint colours of every edge whose sum is non-zero.
    for (int col = 1; col <= k; col++) {
        auto& members = e[col];
        if (members.empty()) continue;
        int top = members[0];
        for (std::size_t idx = 1; idx < members.size(); idx++) {
            top = LCA(members[idx], top);
        }
        for (auto node : members) {
            s[node]++;
            s[top]--;
        }
    }

    for (int col = 1; col <= k; col++) make_set(col);
    dfs(1, 1);
    dfs1(1, 1);

    // Number of colour labels whose cnt[] is exactly 1 (only DSU
    // representatives are ever incremented by dfs1).
    int num = 0;
    for (int col = 1; col <= k; col++) num += (cnt[col] == 1);

    // Prints max(num - 1, 0).
    // NOTE(review): if this solves the "merge states so that no single road
    // separates them" task, the usual answer for L leaf groups is
    // (L + 1) / 2 rather than L - 1 -- confirm against the problem statement.
    cout << (num == 0 ? 0 : num - 1) << endl;
}
// Program entry point: enables fast iostream I/O and runs a single test case.
// Declared `signed` (which is exactly `int`): the file-wide
// `#define int long long` would turn `int main` into `long long main`, and
// C++ does not permit `main` without a declared return type.
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    T = 1;
    while (T--) test_case();
    return 0;
}