#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
// Capacity of every per-vertex / per-state array. Vertex ids and state ids
// are used directly as indices, so both must stay below this value.
constexpr int MAXN = 500005;
// Number of binary-lifting levels; 2^LOGN exceeds MAXN, so any depth gap fits.
constexpr int LOGN = 20;

// Tree adjacency lists, indexed by 1-based vertex id.
vector<int> adj[MAXN];
// state_cities[s] lists the cities whose state id is s.
vector<int> state_cities[MAXN];

// up[v][i] is the 2^i-th ancestor of v (the root is its own ancestor).
int up[MAXN][LOGN];
// depth[v] is v's distance (in edges) from the root.
int depth[MAXN];

// Path-marking counters: a difference array that is later summed over subtrees.
int diff[MAXN];
// DSU parent links over cities.
int parent[MAXN];
// Degree of each DSU group (indexed by its representative) in the contracted tree.
int deg[MAXN];

// n: number of cities; k: number of states.
int n;
int k;
// LCA uchun tayyorgarlik
void dfs_lca(int u, int p, int d) {
depth[u] = d;
up[u][0] = p;
for (int i = 1; i < LOGN; i++)
up[u][i] = up[up[u][i - 1]][i - 1];
for (int v : adj[u]) {
if (v != p) dfs_lca(v, u, d + 1);
}
}
// Returns the lowest common ancestor of u and v, using the depth[] and up[][]
// tables prepared by dfs_lca.
int get_lca(int u, int v) {
    // Make u the deeper (or equally deep) vertex.
    if (depth[v] > depth[u]) swap(u, v);

    // Lift u up to v's depth, taking the largest jumps first.
    for (int step = LOGN - 1; step >= 0; --step) {
        if (depth[v] + (1 << step) <= depth[u]) u = up[u][step];
    }
    if (u == v) return u;

    // Lift both vertices as long as their ancestors still differ; they end up
    // as distinct children of the LCA.
    for (int step = LOGN - 1; step >= 0; --step) {
        const int next_u = up[u][step];
        const int next_v = up[v][step];
        if (next_u == next_v) continue;
        u = next_u;
        v = next_v;
    }
    return up[u][0];
}
// Accumulates the difference array along the tree (path marking): after the
// call, diff[x] holds the sum of the original diff values over x's subtree,
// for the subtree rooted at u with p treated as u's parent. u's own total is
// not added to p.
//
// Uses an explicit stack instead of recursion, because the recursion depth
// could reach the number of vertices and overflow the call stack.
void dfs_diff(int u, int p) {
    // Pre-order list of (vertex, parent). A vertex is pushed only when its
    // parent is popped, so every vertex appears after all of its ancestors.
    vector<pair<int, int>> order;
    vector<pair<int, int>> pending;
    pending.emplace_back(u, p);
    while (!pending.empty()) {
        const pair<int, int> cur = pending.back();
        pending.pop_back();
        order.push_back(cur);
        for (int v : adj[cur.first]) {
            if (v != cur.second) pending.emplace_back(v, cur.first);
        }
    }

    // Walking backwards visits every child before its parent, so diff[child]
    // is already final when it is folded into diff[parent]. Index 0 is u
    // itself; like the recursive version, its total is not pushed into p.
    for (int i = static_cast<int>(order.size()) - 1; i > 0; i--) {
        diff[order[i].second] += diff[order[i].first];
    }
}
// DSU (Disjoint Set Union) used to merge cities into groups.
// Returns the representative of v's set and compresses the path, so every
// vertex on it points directly at the representative afterwards.
//
// Iterative two-pass version: unite() links roots without rank/size
// balancing, so a parent chain can be as long as the number of elements. The
// recursive form could therefore overflow the call stack.
int find_set(int v) {
    int root = v;
    while (parent[root] != root) root = parent[root];

    // Second pass: point each vertex on the path straight at the root.
    while (parent[v] != root) {
        const int next = parent[v];
        parent[v] = root;
        v = next;
    }
    return root;
}
// Merges the DSU sets containing u and v by attaching u's representative
// under v's representative. Does nothing if they already share a set.
void unite(int u, int v) {
    const int root_of_u = find_set(u);
    const int root_of_v = find_set(v);
    if (root_of_u == root_of_v) return;
    parent[root_of_u] = root_of_v;
}
// Reads n, k, the n - 1 tree edges and the state of each city. It marks every
// tree edge lying on a path between cities of the same state, then merges the
// cities across those edges into groups. It prints ceil(leaves / 2), where
// "leaves" counts groups of degree 1 in the contracted tree, or 0 when only
// one group remains.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    if (!(cin >> n >> k)) return 0;

    // n - 1 undirected tree edges.
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    // State of each city; every city also starts as its own DSU group.
    // NOTE(review): state ids index state_cities directly, so this assumes
    // 1 <= s < MAXN. Ids above k are read but never processed below; confirm
    // against the input constraints.
    for (int i = 1; i <= n; i++) {
        int s;
        cin >> s;
        state_cities[s].push_back(i);
        parent[i] = i;
    }

    dfs_lca(1, 1, 0);

    // Mark the paths joining the cities of each state. Every city adds +1 at
    // itself and the state's LCA absorbs -(number of cities). After the
    // subtree summation in dfs_diff, diff[x] > 0 exactly when the edge from x
    // to its parent lies on such a path.
    for (int i = 1; i <= k; i++) {
        const vector<int>& cities = state_cities[i];
        if (cities.empty()) continue;
        int lca = cities.front();
        for (int city : cities) lca = get_lca(lca, city); // get_lca(x, x) == x
        for (int city : cities) diff[city]++;
        diff[lca] -= static_cast<int>(cities.size());
    }
    dfs_diff(1, 1);

    // Merge cities across every marked ("closed") edge between i and its
    // parent into one group. diff[1] is 0 because every state's LCA lies in
    // the root's subtree, so the root never tries to merge with itself.
    for (int i = 1; i <= n; i++) {
        if (diff[i] > 0) unite(i, up[i][0]);
    }

    // Degree of each group in the contracted tree. Every inter-group edge
    // appears in both endpoints' adjacency lists, so it adds 1 to each side.
    for (int u = 1; u <= n; u++) {
        const int root_u = find_set(u); // invariant across u's edges
        for (int v : adj[u]) {
            if (find_set(v) != root_u) deg[root_u]++;
        }
    }

    int leaves = 0;
    int group_count = 0;
    for (int i = 1; i <= n; i++) {
        if (parent[i] != i) continue; // only DSU roots represent groups
        group_count++;
        if (deg[i] == 1) leaves++;
    }

    // If everything merged into a single group, the answer is 0.
    if (group_count <= 1) cout << 0 << '\n';
    else cout << (leaves + 1) / 2 << '\n';
    return 0;
}