#include <bits/stdc++.h>
using namespace std;
// Expands to a container's [begin, end) range for STL algorithms.
// The argument is parenthesized so expressions such as all(*ptr) expand correctly.
#define all(v) (v).begin(), (v).end()
typedef long long ll;

// Size of every per-vertex array; vertices are indexed from 1
// (presumably n <= 5e5 — TODO confirm against the problem limits).
const int NMAX = 5e5 + 5;

int n, k;                 // number of vertices / number of sheep read from input
int a, b;                 // scratch variables for reading edges and sheep vertices
int p[NMAX];              // p[v]: parent of v in the tree rooted at vertex 1 (p[1] stays 0)
int d[NMAX];              // d[v]: depth of v below vertex 1 (d[1] stays 0)
int sheep[NMAX];          // sheep[v] == 1 iff v was listed as a sheep vertex
int vis[NMAX];            // vis[v] == 1 once v has been marked as covered
int dist[NMAX];           // dist[v]: BFS distance from v to the nearest sheep (-1 = unreached)
std::vector<int> adj[NMAX];           // adjacency lists of the tree
std::vector<int> ans;                 // chosen vertices, in selection order
std::queue<int> q;                    // BFS work queue
std::vector<std::pair<int, int>> v;   // (depth, vertex) for every sheep vertex
// Roots the tree at `x` and records depth and parent for every vertex below it.
//
// For each vertex u reachable from x without stepping back onto `par`
// (pass -1 when x is the root), sets:
//   d[u] = d[x] + (number of edges between x and u)
//   p[u] = neighbour of u on the path towards x
// d[x] and p[x] themselves are not modified.
//
// The traversal uses an explicit stack instead of recursion. Recursion depth
// would equal the tree height, which for a path of ~5e5 vertices can exhaust
// the call stack. Because the input is a tree, every vertex is pushed exactly once.
void dfs(int x, int par) {
    std::vector<std::pair<int, int>> stk;  // pending (vertex, parent-of-vertex) pairs
    stk.emplace_back(x, par);
    while (!stk.empty()) {
        const int u = stk.back().first;
        const int from = stk.back().second;
        stk.pop_back();
        for (int nx : adj[u]) {
            if (nx == from) continue;  // do not walk back up the edge we came from
            d[nx] = d[u] + 1;
            p[nx] = u;
            stk.emplace_back(nx, u);
        }
    }
}
// Reads a tree on vertices 1..n (n-1 edges) and k sheep vertices, greedily selects
// a set of vertices, and prints the count followed by the selected vertices.
//
// Steps as implemented:
//   1. Multi-source BFS from all sheep: dist[v] = edge distance from v to its nearest sheep.
//   2. dfs(1, -1) roots the tree at vertex 1, filling depth d[] and parent p[].
//   3. Sheep are processed from deepest to shallowest. For each sheep s not yet covered,
//      climb towards the root while the parent's nearest-sheep distance is exactly one
//      larger than the current vertex's. The highest such vertex x is selected.
//   4. From x, flood along edges where dist drops by exactly 1 and mark every vertex
//      reached (including the sheep at dist 0) as covered.
// NOTE(review): this looks like the greedy for the "place shepherds" tree problem
// (e.g. BOI 2019 Pastiri) — confirm against the intended statement.
int main(void) {
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    cin >> n >> k;
    for (int i = 1; i < n; i++) {
        cin >> a >> b;
        adj[a].emplace_back(b);
        adj[b].emplace_back(a);
    }

    // All bytes 0xFF => every int is -1, which means "not reached yet" for the BFS below.
    // This also leaves dist[0] == -1, which the climbing loop relies on at the root.
    memset(dist, -1, sizeof(dist));
    for (int i = 0; i < k; i++) {
        cin >> a;
        sheep[a] = 1;
        q.emplace(a);
        dist[a] = 0;
    }

    dfs(1, -1);

    // Multi-source BFS: all sheep start at distance 0.
    while (!q.empty()) {
        const int x = q.front();
        q.pop();
        for (int nx : adj[x])
            if (dist[nx] == -1) {
                dist[nx] = dist[x] + 1;
                q.emplace(nx);
            }
    }

    for (int i = 1; i <= n; i++)
        if (sheep[i]) v.emplace_back(d[i], i);
    // Deepest sheep first, ties broken by larger vertex index. All pairs are distinct,
    // so this equals sorting ascending and then reversing.
    sort(all(v), std::greater<>());

    for (const auto& entry : v) {
        const int s = entry.second;
        if (vis[s]) continue;

        // Climb while the parent's nearest-sheep distance is exactly one larger.
        // x > 1 stops at the root (p[1] is never set by dfs and stays 0).
        int x = s;
        while (x > 1 && dist[p[x]] == dist[x] + 1) x = p[x];
        ans.emplace_back(x);

        // Reuse the global queue: it is empty once the multi-source BFS above has
        // drained, and each flood below drains it again before the next iteration.
        q.emplace(x);
        vis[x] = 1;
        while (!q.empty()) {
            const int y = q.front();
            q.pop();
            for (int nx : adj[y])
                if (!vis[nx] && dist[nx] == dist[y] - 1) {
                    vis[nx] = 1;
                    q.emplace(nx);
                }
        }
    }

    cout << ans.size() << '\n';
    for (int x : ans) cout << x << ' ';
    return 0;
}
/*
 * Online-judge submission status table, pasted from the judge page.
 * It is not part of the program and is kept here only as a note.
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */