//#pragma GCC optimize("O3,unroll-loops")
//#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include <bits/stdc++.h>
using namespace std;
// Contest shorthand macros (kept verbatim; code elsewhere uses `pb`).
#define ff first
#define ss second
#define pb push_back
#define pf push_front
#define mp make_pair
#define ll long long

// Adjacency lists of the tree; vertices are 0-indexed once read in main.
vector<vector<int>> graph(500001);

// state[v]: state id of city v, as read from input (main iterates ids 1..k).
int state[500001];
// parent[v][i]: 2^i-th ancestor of v (binary-lifting table built in main;
// the root, vertex 0, is its own parent).
int parent[500001][20];
// depth[v]: number of edges between v and the root.
int depth[500001];
// lcastate[s]: lowest common ancestor of all cities of state s.
int lcastate[500001];

// isok[v]: true while v has not yet been absorbed into a merged component.
bool isok[500001];

// statevec[s]: cities belonging to state s.
vector<int> statevec[500001];
// revlcastate[v]: states whose computed LCA is v.
vector<int> revlcastate[500001];
// curprocessed: vertices absorbed into the component currently being built.
vector<int> curprocessed;
// Roots the tree at `a`: for every vertex v reachable from `a`, sets
// parent[v][0] to v's immediate parent and depth[v] to depth[parent] + 1.
// The caller must initialise parent[a][0] and depth[a] beforehand
// (main uses parent[0][0] = 0 and depth[0] = 0, i.e. the root is its own parent).
//
// Uses an explicit stack instead of recursion: on a path-shaped tree the
// recursion depth would reach n - 1 (~5e5 frames), which can overflow the
// default call stack.
void dfs(int a) {
    vector<int> todo(1, a);
    while (!todo.empty()) {
        const int u = todo.back();
        todo.pop_back();
        for (int v : graph[u]) {
            if (v == parent[u][0]) continue;  // don't walk back up the tree
            parent[v][0] = u;
            depth[v] = depth[u] + 1;
            todo.push_back(v);
        }
    }
}
// Absorbs vertex `a` into the component currently being built, then keeps
// absorbing everything that is forced into the same component:
//  * every state listed in revlcastate[v] for an absorbed vertex v
//    (main fills revlcastate[v] with the states whose LCA is v), and
//  * the parent of every absorbed vertex that is not yet absorbed, together
//    with every city of that parent's state. This walks each absorbed vertex
//    upward until it reaches an already-absorbed ancestor.
// Every vertex absorbed here (including `a`) is appended to `curprocessed`.
//
// Preconditions: isok[a] has already been cleared by the caller. Callers
// clear a whole state at once before calling notok on each of its cities.
// Invariant relied on: a state's cities are always cleared together, so a
// state is either fully absorbed or not at all.
//
// Uses an explicit work list instead of recursion. The recursive form nested
// once per absorbed vertex (up to ~5e5 deep on a long path), which overflows a
// default-sized call stack. The absorbed set is the same regardless of the
// order in which the work list is drained.
void notok(int a) {
    vector<int> pending(1, a);
    while (!pending.empty()) {
        const int v = pending.back();
        pending.pop_back();
        curprocessed.push_back(v);

        // States hanging below v with their LCA at v must join this component.
        for (int s : revlcastate[v]) {
            bool fresh = true;
            for (int c : statevec[s]) {
                if (!isok[c]) {  // already absorbed (whole state, by invariant)
                    fresh = false;
                    break;
                }
                isok[c] = false;
            }
            if (fresh)
                for (int c : statevec[s]) pending.push_back(c);
        }

        // Pull in the parent (and its whole state) if it is still free; the
        // parent's own processing continues the upward walk.
        const int up = parent[v][0];
        if (isok[up]) {
            for (int c : statevec[state[up]]) isok[c] = false;
            for (int c : statevec[state[up]]) pending.push_back(c);
        }
    }
}
// Lowest common ancestor of vertices a and b via binary lifting.
// Assumes depth[] is filled and parent[v][i] is the 2^i-th ancestor of v,
// saturating at the root (which is its own parent) — as built in main.
int lca(int a, int b) {
    if (depth[a] < depth[b]) swap(a, b);
    // Lift a to exactly b's depth (never past it). Stopping one level short and
    // then stepping unconditionally overshoots when a and b start at the same
    // depth (e.g. siblings), yielding a too-high ancestor.
    for (int i = 19; i >= 0; i--)
        if (depth[parent[a][i]] >= depth[b]) a = parent[a][i];
    if (a == b) return a;
    // Same depth, different vertices: lift both while their ancestors differ.
    for (int i = 19; i >= 0; i--) {
        if (parent[a][i] != parent[b][i]) {
            a = parent[a][i];
            b = parent[b][i];
        }
    }
    return parent[a][0];
}
int main() {
for (int i=0;i<500000;i++) {
isok[i] = true;
}
int n, k, d1, d2; cin >> n >> k;
for (int i=0;i<n-1;i++) {
cin >> d1 >> d2;
d1--; d2--;
graph[d1].pb(d2);
graph[d2].pb(d1);
}
for (int i=0;i<n;i++) {
cin >> state[i];
statevec[state[i]].pb(i);
}
depth[0] = 0, parent[0][0] = 0;
dfs(0);
for (int i=1;i<20;i++) for (int j=0;j<n;j++) parent[j][i] = parent[parent[j][i - 1]][i - 1];
for (int i=1;i<k+1;i++) {
lcastate[i] = statevec[i][0];
for (int j=1;j<statevec[i].size();j++) lcastate[i] = lca(lcastate[i], statevec[i][j]);
revlcastate[lcastate[i]].pb(i);
}
stack<int> nums; nums.push(0);
vector<int> newnodes;
vector<vector<int>> newgraph(n);
int ans = 0, node;
while (nums.size() > 0) {
node = nums.top(); nums.pop();
//cout << node << endl;
for (int i: statevec[state[node]]) isok[i] = false;
for (int i: statevec[state[node]]) notok(i);
for (int i: curprocessed) for (int j: graph[i]) if (isok[j]) newnodes.pb(j);
for (int i: newnodes) {
nums.push(i);
newgraph[node].pb(i);
newgraph[i].pb(node);
}
newnodes.clear();
curprocessed.clear();
}
for (int i=0;i<n;i++) if (newgraph[i].size() == 1) ans += 1;
cout << (ans + 1) / 2;
}