#include <bits/stdc++.h>
#pragma GCC optimize ("Ofast,unroll-loops")
#pragma GCC target ("avx2")
using namespace std;
// Short type aliases used throughout the file.
typedef long long ll;
typedef pair<int, int> pp;
// Debug print: er(x, y) writes "<line>: x = <x>, y = <y>, " and then endl to stderr via err().
// The names come from splitting the stringized argument list on whitespace, so separate
// arguments with ", " (e.g. er(a, b), not er(a,b)).
// NOTE(review): the istringstream created with `new` is never deleted, so every use leaks it.
#define er(args ...) cerr << __LINE__ << ": ", err(new istringstream(string(#args)), args), cerr << endl
// per: descending loop i = r, r-1, ..., l (inclusive); rep: ascending loop i = l, ..., r-1.
#define per(i,r,l) for(int i = (r); i >= (l); i--)
#define rep(i,l,r) for(int i = (l); i < (r); i++)
// Whole-container iterator pair, e.g. sort(all(v)).
#define all(x) begin(x), end(x)
// Container size as a signed int.
#define sz(x) (int)(x).size()
// NOTE: the following replace EVERY `pb`/`ss`/`ff` token in the file, not only member accesses.
#define pb push_back
#define ss second
#define ff first
// Debug helper behind the `er` macro: for each value, reads the next whitespace-separated
// name from *iss and prints "name = value, " to stderr. The base overload ends the recursion.
void err(std::istringstream *iss){}
template<typename T, typename ...Args>
void err(std::istringstream *iss, const T &val, const Args&... args){
    std::string name;
    *iss >> name;
    // When the names run out (e.g. `er(a,b)` stringizes to the single token "a,b"), the
    // extraction fails and `name` stays empty; back() on an empty string would be UB.
    if(!name.empty() && name.back() == ',') name.pop_back();
    std::cerr << name << " = " << val << ", ";
    err(iss, args...);
}
// Fast stream I/O: stop synchronising C++ streams with C stdio and untie cin from cout,
// so reading does not flush cout first.
void IOS(){
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    // Local testing: uncomment to read/write files instead of stdin/stdout.
    // #ifndef ONLINE_JUDGE
    // freopen("in.in", "r", stdin);
    // freopen("out.out", "w", stdout);
    // #endif
}
// Pseudo-random engine seeded from the monotonic clock (a different stream on each run).
std::mt19937 rng(static_cast<std::mt19937::result_type>(
    std::chrono::steady_clock::now().time_since_epoch().count()));
// mod: default modulus for pw(); maxn: array capacity (max vertices + slack);
// lg: binary-lifting levels; inf: "unreached / not computed" sentinel.
const long long mod = 1'000'000'007;
const long long maxn = 500'005;
const long long lg = 22;
const long long inf = 1'000'000'005;
// Modular exponentiation by iterative square-and-multiply: returns a^b mod md
// (returns 1 when b <= 0). Every intermediate is reduced modulo md — including the base —
// so products stay below md^2; assumes md^2 fits in ll (md below ~3.03e9) — TODO confirm
// for any non-default md.
ll pw(ll a,ll b,ll md=mod){
    ll res = 1;
    a %= md;
    if(a < 0) a += md;  // normalise negative bases into [0, md)
    for(; b > 0; b >>= 1){
        if(b & 1) res = res * a % md;
        a = a * a % md;
    }
    return res;
}
// Sparse table answering range-minimum queries over a fixed array.
// Build is O(n log n); get(l, r) is O(1) for a 0-based inclusive range with 0 <= l <= r < n.
template<class T> struct Rmq{
    // rmq[i][j] = minimum of a[j .. j + 2^i - 1], valid for j <= n - 2^i.
    std::vector<std::vector<T>> rmq;
    // lgg[x] = floor(log2(x)) for x >= 1 (lgg[0] = 0).
    std::vector<int> lgg;
    static T comb(T a, T b){
        return std::min(a, b);
    }
    Rmq(){}
    Rmq(std::vector<T> a){
        const int n = (int)a.size();
        lgg.assign(n + 1, 0);
        for(int i = 2; i <= n; i++) lgg[i] = lgg[i >> 1] + 1;
        const int levels = lgg[n] + 1;
        rmq.assign(levels, std::vector<T>());
        // `a` is already our own copy (taken by value): move it in rather than copying again.
        rmq[0] = std::move(a);
        for(int i = 1; i < levels; i++){
            rmq[i].resize(n);
            const int half = 1 << (i - 1);
            for(int j = 0; j + (1 << i) <= n; j++)
                rmq[i][j] = comb(rmq[i-1][j], rmq[i-1][j + half]);
        }
    }
    T get(int l, int r) const { // 0-base [l, r]
        // Two overlapping power-of-two windows cover [l, r]; min is idempotent.
        const int g = lgg[r - l + 1];
        return comb(rmq[g][l], rmq[g][r - (1 << g) + 1]);
    }
};
// Range-minimum table over the Euler-tour depth sequence (built in main).
Rmq<int> rmq;
// Tree adjacency lists.
vector<int> adj[maxn];
// Sheep vertices (0-based).
vector<int> nd;
// Per-centroid value sets.
set<int> st[maxn];
// par[x][i]: 2^i-th ancestor of x via binary lifting (lifting saturates at the root).
int par[maxn][lg];
int h[maxn];      // depth from the DFS root
int cnt[maxn];    // subtree sizes written by dfs2
int nxt[maxn];    // parent of each centroid in the centroid tree (-1 at the top)
int vl[maxn];     // memoised centroid-tree depth (dfs3)
int mn[maxn];     // distance to the nearest sheep (bfs)
int s[maxn];      // index of each vertex's first occurrence in the Euler tour
int t;            // number of Euler-tour entries written so far
int v[maxn<<1];   // Euler-tour depth sequence
bool ban[maxn];   // vertex already taken as a centroid
void dfs(int r, int p){
v[t] = h[r], s[r] = t++, par[r][0] = p;
rep(i,1,lg) par[r][i] = par[par[r][i-1]][i-1];
for(int c: adj[r]) if(c - p) h[c] = h[r] + 1, dfs(c, r), v[t++] = h[r];
}
// int lca(int u, int v){
// if(h[u] > h[v]) swap(u, v);
// per(i,lg-1,0) if(h[par[v][i]] >= h[u]) v = par[v][i];
// if(u == v) return u;
// per(i,lg-1,0) if(par[v][i] - par[u][i]) u = par[u][i], v = par[v][i];
// return par[u][0];
// }
void dfs2(int r, int p){
cnt[r] = 1;
for(int c: adj[r]) if(c - p && !ban[c]) dfs2(c, r), cnt[r] += cnt[c];
}
// Starting at r (entered from p), keep stepping into the first unbanned neighbour whose
// cnt exceeds tot / 2; returns the vertex at which no such neighbour exists.
// Expects cnt[] from dfs2 rooted at the component's starting vertex.
int find_cent(int r, int p, int tot){
    for(;;){
        int heavy = -1;
        for(int c: adj[r]){
            if(c != p && !ban[c] && cnt[c] * 2 > tot){
                heavy = c;
                break;
            }
        }
        if(heavy < 0) return r;
        p = r;
        r = heavy;
    }
}
// Centroid decomposition of r's component: finds its centroid, marks it banned, records tp
// as that centroid's parent in nxt[] (-1 at the top level), then recurses into each
// remaining piece adjacent to the centroid.
void decom(int r, int tp){
    dfs2(r, -1);
    const int centroid = find_cent(r, -1, cnt[r]);
    ban[centroid] = true;
    nxt[centroid] = tp;
    for(int nb: adj[centroid]){
        if(ban[nb]) continue;
        decom(nb, centroid);
    }
}
int dist(int u, int v){
if(s[u] > s[v]) swap(u, v);
return h[u] + h[v] - (rmq.get(s[u], s[v])<<1);
}
void bfs(){
fill(mn, mn + maxn, inf);
deque<int> q;
for(int c: nd) mn[c] = 0, q.pb(c);
while(sz(q)){
int r = q.front(); q.pop_front();
for(int c: adj[r]) if(mn[c] == inf) mn[c] = mn[r] + 1, q.pb(c);
}
}
// Depth of r in the centroid tree (r == -1 gives 0, a top-level centroid gives 1),
// memoised in vl[], where inf means "not computed yet".
int dfs3(int r){
    if(r == -1) return 0;
    if(vl[r] != inf) return vl[r];
    vl[r] = dfs3(nxt[r]) + 1;
    return vl[r];
}
int main(){ IOS();
int n, k; cin >> n >> k;
rep(i,1,n){
int u, v; cin >> u >> v; u--, v--;
adj[u].pb(v), adj[v].pb(u);
}
nd.resize(k);
rep(i,0,k) cin >> nd[i], nd[i]--;
dfs(0, 0), decom(0, -1), bfs();
rmq = Rmq(vector<int>(v, v + t));
fill(vl, vl + maxn, inf);
rep(i,0,n) assert(dfs3(i) <= lg);
sort(all(nd), [&](int i, int j){ return h[i] > h[j]; });
vector<int> ans;
auto add = [&](int u){
ans.pb(u);
int cr = u;
while(cr + 1){
int d = dist(u, cr);
if(mn[u] >= d) st[cr].insert(mn[u] - d);
cr = nxt[cr];
}
};
auto chk = [&](int u){
int cr = u;
while(cr + 1){
if(st[cr].find(dist(cr, u)) != end(st[cr])) return true;
cr = nxt[cr];
}
return false;
};
auto get = [&](int u){
int k = u;
per(i,lg-1,0) if(h[u] - h[par[k][i]] == mn[par[k][i]]) k = par[k][i];
return k;
};
for(int c: nd){
if(chk(c)) continue;
add(get(c));
}
cout << sz(ans) << '\n'; sort(all(ans));
for(int c: ans) cout << ++c << ' '; cout << '\n';
return 0;
}
// Compilation message:
// pastiri.cpp: In function 'int main()':
// pastiri.cpp:160:6: warning: this 'for' clause does not guard... [-Wmisleading-indentation]
//   160 |  for(int c: ans) cout << ++c << ' '; cout << '\n';
//       |  ^~~
// pastiri.cpp:160:42: note: ...this statement, but the latter is misleadingly indented as if it were guarded by the 'for'
//   160 |  for(int c: ans) cout << ++c << ' '; cout << '\n';
//       |                                      ^~~~
// #  | Result  | Time   | Memory    | Grader output
// 1  | Correct | 366 ms | 237640 KB | Output is correct
// 2  | Correct | 410 ms | 237532 KB | Output is correct
// 3  | Correct | 442 ms | 237616 KB | Output is correct
// 4  | Correct | 703 ms | 251260 KB | Output is correct
// #  | Result              | Time    | Memory    | Grader output
// 1  | Correct             | 21 ms   | 40828 KB  | Output is correct
// 2  | Correct             | 24 ms   | 40788 KB  | Output is correct
// 3  | Correct             | 767 ms  | 198856 KB | Output is correct
// 4  | Correct             | 703 ms  | 201168 KB | Output is correct
// 5  | Execution timed out | 1094 ms | 106600 KB | Time limit exceeded
// 6  | Halted              | 0 ms    | 0 KB      | -
// #  | Result  | Time  | Memory   | Grader output
// 1  | Correct | 19 ms | 40020 KB | Output is correct
// 2  | Correct | 22 ms | 39976 KB | Output is correct
// 3  | Correct | 20 ms | 40020 KB | Output is correct
// 4  | Correct | 23 ms | 39892 KB | Output is correct
// 5  | Correct | 20 ms | 39976 KB | Output is correct
// 6  | Correct | 19 ms | 40036 KB | Output is correct
// 7  | Correct | 19 ms | 39892 KB | Output is correct
// 8  | Correct | 19 ms | 40020 KB | Output is correct
// 9  | Correct | 18 ms | 40020 KB | Output is correct
// 10 | Correct | 18 ms | 39948 KB | Output is correct
// 11 | Correct | 19 ms | 39764 KB | Output is correct
// 12 | Correct | 18 ms | 39700 KB | Output is correct
// 13 | Correct | 23 ms | 39892 KB | Output is correct
// 14 | Correct | 20 ms | 40020 KB | Output is correct
// 15 | Correct | 23 ms | 39928 KB | Output is correct
// #  | Result              | Time    | Memory    | Grader output
// 1  | Execution timed out | 1087 ms | 109484 KB | Time limit exceeded
// 2  | Halted              | 0 ms    | 0 KB      | -