#include <bits/stdc++.h>
// #pragma GCC optimize ("Ofast,unroll-loops")
// #pragma GCC target ("avx2")
using namespace std;
// Short type aliases used throughout the solution.
using ll = long long;
using pp = pair<int, int>;
// Debug print: writes the source line, then each argument via err(), to stderr.
#define er(args ...) cerr << __LINE__ << ": ", err(new istringstream(string(#args)), args), cerr << endl
// Loop helpers: rep walks [l, r) upward, per walks [l, r] downward.
#define rep(i,l,r) for(int i = (l); i < (r); ++i)
#define per(i,r,l) for(int i = (r); i >= (l); --i)
// Container helpers.
#define all(x) begin(x), end(x)
#define sz(x) static_cast<int>((x).size())
#define pb push_back
// pair member shorthands.
#define ff first
#define ss second
// Debug printer used by the er(...) macro. For each argument it prints
// "name = value, " to stderr, taking the names from the stringified argument
// list in *iss. Names are split at commas (surrounding spaces trimmed), so
// er(a,b) and er(x + y, z) are labelled correctly. A comma inside a nested
// call such as f(a, b) still splits that name. Running out of names prints an
// empty name instead of calling back() on an empty string.
// Takes ownership of iss (er() allocates it with new) and deletes it once
// every argument has been printed.
void err(istringstream *iss){ delete iss; }
template<typename T,typename ...Args> void err(istringstream *iss,const T &_val, const Args&...args){
    string _name;
    getline(*iss >> ws, _name, ',');
    while(!_name.empty() && isspace(static_cast<unsigned char>(_name.back()))) _name.pop_back();
    cerr << _name << " = " << _val << ", ";
    err(iss, args...);
}
// Unties cin from cout and disables C-stdio synchronisation. Unless
// ONLINE_JUDGE is defined, it also redirects stdin/stdout to "in.in"/"out.out",
// but only when "in.in" actually exists. freopen() closes the original stream
// before opening the new file. On a machine without "in.in" (for example a
// judge that does not define ONLINE_JUDGE), an unconditional redirect would
// leave the program with no stdin at all.
void IOS(){
    cin.tie(0) -> sync_with_stdio(0);
#ifndef ONLINE_JUDGE
    if(FILE *probe = fopen("in.in", "r")){
        fclose(probe);
        if(!freopen("in.in", "r", stdin) || !freopen("out.out", "w", stdout))
            cerr << "IOS: failed to redirect local I/O\n";
    }
#endif
}
// Random engine seeded from the steady clock (not used by main).
std::mt19937 rng(static_cast<std::mt19937::result_type>(
    std::chrono::steady_clock::now().time_since_epoch().count()));
// mod: default modulus for pw; maxn: capacity of the per-node arrays;
// lg: binary-lifting levels; inf: "not reached" sentinel for distances.
constexpr long long mod = 1000000007LL;
constexpr long long maxn = 500005LL;
constexpr long long lg = 22;
constexpr long long inf = 1000000005LL;
// Modular exponentiation: returns a^b mod md, in [0, md), for b >= 0 and
// md >= 1. Intermediate products stay below md^2, so md must be at most
// about 3e9 to avoid overflowing ll.
ll pw(ll a,ll b,ll md=mod){
    a %= md;
    if(a < 0) a += md;                       // normalise negative bases
    if(!b) return 1 % md;                    // md == 1 must give 0
    // The modulus has to be passed down. Falling back to the default `mod`
    // here made every inner power reduce modulo 1e9+7 instead of md.
    const ll half = pw(a, b >> 1, md);
    return half * half % md * (b & 1 ? a : 1) % md;
}
// Per-node global state, indexed by 0-based node id.
vector<int> adj[maxn];   // tree adjacency lists (filled in main)
vector<int> nd;          // marked nodes read from input (BFS sources of bfs)
set<int> st[maxn];       // one ordered int set per node
int par[maxn][lg];       // par[v][i]: 2^i-th ancestor of v (filled by dfs)
int h[maxn];             // depth below the dfs root
int cnt[maxn];           // subtree sizes inside the current component (dfs2)
int nxt[maxn];           // parent centroid set by decom (-1 for the top one)
int vl[maxn];            // centroid-tree depth cache used by dfs3
int mn[maxn];            // distance to the nearest node of nd (bfs)
bool ban[maxn];          // node already chosen as a centroid by decom
// Roots the tree at r, with p as r's parent. main passes p == r, so the root
// is its own ancestor and lifting clamps there. Fills h[c] = h[parent] + 1
// for every node below r (h[r] itself is not touched) and the binary-lifting
// table par[v][i] = 2^i-th ancestor of v.
// Uses an explicit stack instead of recursion. A recursive walk nests once per
// tree level, which on a path of ~5e5 nodes can overflow the call stack.
void dfs(int r, int p){
    vector<pair<int, int>> stk;              // (node, parent) pairs still to visit
    stk.emplace_back(r, p);
    while(!stk.empty()){
        const int u = stk.back().first, from = stk.back().second;
        stk.pop_back();
        // A node is pushed only after its parent was popped, so every
        // ancestor's row of par[][] is already complete at this point.
        par[u][0] = from;
        for(int i = 1; i < lg; i++) par[u][i] = par[par[u][i-1]][i-1];
        for(int c: adj[u]) if(c != from){
            h[c] = h[u] + 1;
            stk.emplace_back(c, u);
        }
    }
}
// Lowest common ancestor of u and v, using the binary-lifting table par[][]
// and the depths h[] filled by dfs.
int lca(int u, int v){
    if(h[v] > h[u]) swap(u, v);              // make u the deeper (or equal) node
    for(int j = lg - 1; j >= 0; --j){
        const int up = par[u][j];
        if(h[up] >= h[v]) u = up;            // climb without going above v's depth
    }
    if(u == v) return v;
    for(int j = lg - 1; j >= 0; --j){
        if(par[u][j] != par[v][j]){
            u = par[u][j];
            v = par[v][j];
        }
    }
    return par[v][0];
}
// Sets cnt[x] to the size of x's subtree when the component of non-banned
// nodes containing r is rooted at r. p is the neighbour to exclude (-1 for
// none).
// Iterative, so component depth cannot overflow the call stack. A BFS order
// lists every parent before its children, so sweeping it backwards
// completes each subtree before adding it to its parent.
void dfs2(int r, int p){
    static vector<pair<int, int>> order;     // (node, parent); reused across calls
    order.clear();
    order.emplace_back(r, p);
    for(size_t i = 0; i < order.size(); i++){
        const int u = order[i].first, from = order[i].second;
        cnt[u] = 1;
        for(int c: adj[u]) if(c != from && !ban[c]) order.emplace_back(c, u);
    }
    for(size_t i = order.size(); i-- > 1; ) cnt[order[i].second] += cnt[order[i].first];
}
// Starting at r (cnt[] as computed by dfs2 from r), keeps stepping into the
// first non-banned neighbour, other than the one it came from, whose subtree
// holds more than tot/2 nodes. Returns the node where no such neighbour exists.
int find_cent(int r, int p, int tot){
    for(;;){
        int heavy = -1;
        for(int c: adj[r]){
            if(c != p && !ban[c] && 2 * cnt[c] > tot){
                heavy = c;
                break;
            }
        }
        if(heavy == -1) return r;
        p = r;
        r = heavy;
    }
}
// Centroid decomposition of the component of non-banned nodes containing r.
// The chosen centroid is banned and gets nxt[] = tp (-1 for the first call
// from main). Each remaining neighbouring component is then processed
// recursively with this centroid as its parent.
void decom(int r, int tp){
    dfs2(r, -1);                                  // subtree sizes inside this component
    const int cen = find_cent(r, -1, cnt[r]);
    ban[cen] = true;
    nxt[cen] = tp;
    for(int nb: adj[cen])
        if(!ban[nb]) decom(nb, cen);
}
// Number of edges on the tree path between u and v.
int dist(int u, int v){
    const int w = lca(u, v);
    return (h[u] - h[w]) + (h[v] - h[w]);
}
// Multi-source BFS from every node of nd. Afterwards mn[x] is the edge
// distance from x to its nearest node of nd, or inf if x was never reached.
void bfs(){
    fill(mn, mn + maxn, inf);
    vector<int> order;                            // doubles as the BFS queue
    for(int s: nd){
        mn[s] = 0;
        order.push_back(s);
    }
    for(size_t head = 0; head < order.size(); ++head){
        const int x = order[head];
        for(int y: adj[x]){
            if(mn[y] != inf) continue;
            mn[y] = mn[x] + 1;
            order.push_back(y);
        }
    }
}
// Memoised depth of r in the centroid tree: 0 for r == -1, otherwise one more
// than the depth of nxt[r]. vl[] must be pre-filled with inf; results are
// cached there.
int dfs3(int r){
    if(r == -1) return 0;
    if(vl[r] != inf) return vl[r];
    vl[r] = dfs3(nxt[r]) + 1;
    return vl[r];
}
int main(){ IOS();
int n, k; cin >> n >> k;
rep(i,1,n){
int u, v; cin >> u >> v; u--, v--;
adj[u].pb(v), adj[v].pb(u);
}
nd.resize(k);
rep(i,0,k) cin >> nd[i], nd[i]--;
dfs(0, 0), decom(0, -1), bfs();
fill(vl, vl + maxn, inf);
rep(i,0,n) assert(dfs3(i) <= lg);
sort(all(nd), [&](int i, int j){ return h[i] > h[j]; });
vector<int> ans;
auto add = [&](int u){
ans.pb(u);
int cr = u;
while(cr + 1){
int d = dist(u, cr);
if(mn[u] >= d) st[cr].insert(mn[u] - d);
cr = nxt[cr];
}
};
auto chk = [&](int u){
int cr = u;
while(cr + 1){
if(st[cr].find(dist(cr, u)) != end(st[cr])) return true;
cr = nxt[cr];
}
return false;
};
auto get = [&](int u){
int k = u;
per(i,lg-1,0) if(h[u] - h[par[k][i]] == mn[par[k][i]]) k = par[k][i];
return k;
};
for(int c: nd){
if(chk(c)) continue;
add(get(c));
}
cout << sz(ans) << '\n'; sort(all(ans));
for(int c: ans) cout << ++c << ' '; cout << '\n';
return 0;
}
Compilation message
pastiri.cpp: In function 'int main()':
pastiri.cpp:130:6: warning: this 'for' clause does not guard... [-Wmisleading-indentation]
130 | for(int c: ans) cout << ++c << ' '; cout << '\n';
| ^~~
pastiri.cpp:130:42: note: ...this statement, but the latter is misleadingly indented as if it were guarded by the 'for'
130 | for(int c: ans) cout << ++c << ' '; cout << '\n';
| ^~~~
pastiri.cpp: In function 'void IOS()':
pastiri.cpp:19:18: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
19 | freopen("in.in", "r", stdin);
| ~~~~~~~^~~~~~~~~~~~~~~~~~~~~
pastiri.cpp:20:18: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
20 | freopen("out.out", "w", stdout);
| ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~
| # | 결과 (result) | 실행 시간 (run time) | 메모리 (memory) | Grader output |
|---|---|---|---|---|
| 1 | Runtime error | 215 ms | 524288 KB | Execution killed with signal 9 |
| 2 | Halted | 0 ms | 0 KB | - |
| # | 결과 (result) | 실행 시간 (run time) | 메모리 (memory) | Grader output |
|---|---|---|---|---|
| 1 | Execution timed out | 1043 ms | 524288 KB | Time limit exceeded |
| 2 | Halted | 0 ms | 0 KB | - |
| # | 결과 (result) | 실행 시간 (run time) | 메모리 (memory) | Grader output |
|---|---|---|---|---|
| 1 | Runtime error | 49 ms | 72040 KB | Execution killed with signal 6 |
| 2 | Halted | 0 ms | 0 KB | - |
| # | 결과 (result) | 실행 시간 (run time) | 메모리 (memory) | Grader output |
|---|---|---|---|---|
| 1 | Runtime error | 53 ms | 71964 KB | Execution killed with signal 6 |
| 2 | Halted | 0 ms | 0 KB | - |