#include <bits/stdc++.h>
using namespace std;
// NOTE(review): these macros remap fundamental type keywords for the rest of
// the file. Redefining keywords via #define is undefined behavior in a
// translation unit that includes standard headers ([macro.names]); it happens
// to work on GCC/Clang. Kept as-is because main below depends on them.
// `short` -> int32_t: so that `short main()` expands to `int32_t main()`
// (plain int on common platforms) even though `int` itself is remapped.
#define short int32_t
// `int` -> int64_t: every `int` written after this line is 64-bit.
#define int int64_t
// `long` -> __int128_t (GCC/Clang extension); not referenced by main.
#define long __int128_t
// Large sentinel: int64_t max / 4 (numeric_limits<int> here is really
// numeric_limits<int64_t>); presumably sized so sums of a few sentinels stay
// in range. Not referenced by main.
const int inf{numeric_limits<int>::max() / 4};
// Input: n and k, then n-1 tree edges (1-indexed endpoints), then a label
// s[i] in [1, k] for every node i (labels are assumed valid).
//
// Algorithm:
//   1. Root the tree at node 0 and build skew-binary jump pointers, giving
//      O(log n) LCA queries with O(n) memory.
//   2. For every label, compute the LCA of all nodes carrying it; then merge
//      (DSU) every node with the whole path up to its label's LCA, so each
//      label's minimal spanning subtree lies inside a single component.
//   3. Components are connected subtrees, so contracting them yields a tree.
//      Count its leaves L and print ceil(L / 2) -- the classic number of extra
//      edges needed to make a tree 2-edge-connected (presumably what the task
//      asks for; confirm against the statement).
//
// Every traversal is iterative: on a path-shaped tree a recursive DFS or a
// recursive DSU find would nest ~n calls deep and can overflow the stack.
short main() {
#ifndef LOCAL
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
#endif
    int n, k;
    cin >> n >> k;
    vector<vector<int>> edges(n);
    for (int i{0}; i < n - 1; i++) {
        int a, b;
        cin >> a >> b;
        a--;
        b--;
        edges[a].push_back(b);
        edges[b].push_back(a);
    }

    // parents[v]: parent in the tree rooted at 0 (-1 for the root).
    // depths[v]:  distance from the root.
    // jumps[v]:   skew-binary jump pointer -- an ancestor of v (v itself for
    //             the root) whose depth is a function of depths[v] alone.
    vector<int> parents(n, -1);
    vector<int> depths(n, 0);
    vector<int> jumps(n, 0);  // jumps[0] == 0: the root jumps to itself.
    // Iterative preorder DFS: a child's jump pointer is derived from its
    // parent's, which is already final when the parent is popped.
    vector<int> pending{0};
    while (!pending.empty()) {
        int const node{pending.back()};
        pending.pop_back();
        int const up{jumps[node]};
        // If the jump out of `node` and the jump out of `up` have equal length,
        // a child may skip both at once; otherwise it points at its parent.
        int const child_jump{depths[node] - depths[up] == depths[up] - depths[jumps[up]]
                                 ? jumps[up]
                                 : node};
        for (int to : edges[node]) {
            if (to == parents[node]) {
                continue;
            }
            parents[to] = node;
            depths[to] = depths[node] + 1;
            jumps[to] = child_jump;
            pending.push_back(to);
        }
    }

    // Lift the deeper node to the other's depth, then climb both in lockstep.
    // Nodes of equal depth have jump targets of equal depth, so a pair of jumps
    // is taken only while the targets still differ (i.e. stay below the LCA).
    auto get_lca = [&](int a, int b) {
        if (depths[a] < depths[b]) {
            swap(a, b);
        }
        while (depths[a] > depths[b]) {
            a = depths[jumps[a]] >= depths[b] ? jumps[a] : parents[a];
        }
        while (a != b) {
            if (jumps[a] == jumps[b]) {
                a = parents[a];
                b = parents[b];
            } else {
                a = jumps[a];
                b = jumps[b];
            }
        }
        return a;
    };

    // s[i]: 0-based label of node i; s_lca[c]: LCA of all nodes labelled c
    // seen so far (-1 while none has been seen).
    vector<int> s(n);
    vector<int> s_lca(k, -1);
    for (int i{0}; i < n; i++) {
        cin >> s[i];
        s[i]--;
        s_lca[s[i]] = s_lca[s[i]] == -1 ? i : get_lca(s_lca[s[i]], i);
    }

    // DSU over nodes; dsu[v] < 0 marks a representative. join() keeps the
    // shallower representative, and components only grow by linking a
    // component's top to that top's parent, so every component is a connected
    // subtree whose representative is its topmost node.
    vector<int> dsu(n, -1);
    auto find = [&](int node) {
        int root{node};
        while (dsu[root] >= 0) {
            root = dsu[root];
        }
        // Path compression, done iteratively since chains can be ~n long.
        while (dsu[node] >= 0) {
            int const next{dsu[node]};
            dsu[node] = root;
            node = next;
        }
        return root;
    };
    auto join = [&](int a, int b) {
        a = find(a);
        b = find(b);
        if (a == b) {
            return;
        }
        if (depths[a] > depths[b]) {
            swap(a, b);
        }
        dsu[b] = a;
    };
    // Merge the whole path from `child` up to its ancestor `parent`, hopping
    // straight to the top of the current component after every merge.
    auto walk = [&](int child, int parent) {
        while (depths[child] > depths[parent]) {
            join(child, parents[child]);
            child = find(child);
        }
    };
    for (int i{0}; i < n; i++) {
        walk(i, s_lca[s[i]]);
    }

    // Degree of each component in the contracted tree. Two components share
    // at most one tree edge (a second one would close a cycle), so counting
    // crossing edges per representative equals counting distinct neighbours.
    vector<int> degree(n, 0);
    for (int i{0}; i < n; i++) {
        int const ci{find(i)};
        for (int to : edges[i]) {
            if (ci != find(to)) {
                degree[ci]++;
            }
        }
    }
    int leaves{0};
    for (int d : degree) {
        if (d == 1) {
            leaves++;
        }
    }
    cout << (leaves + 1) / 2 << "\n";
    return 0;
}