#include <bits/stdc++.h>
// Short aliases used throughout the file. NOTE: `f` and `sc` are object-like
// macros, so every later identifier spelled `f` or `sc` is rewritten to
// `first` / `second`.
#define pb push_back
#define f first
#define sc second
using namespace std;
// ll and str are not used anywhere in the code shown here.
typedef long long int ll;
typedef string str;
// Number of binary-lifting levels in anc (jumps of 1, 2, ..., 2^(k-1) chain
// steps). Presumably chosen so that 2^k exceeds the largest n -- confirm
// against the input limits.
const int k = 19;
// Cap for accumulated chain distances: sums are clamped to inf, which is
// presumably meant to exceed every real distance (at most about n).
const int inf = 1e8;
// n: number of vertices; index n is additionally used as a sentinel vertex.
// m: read from the input but not used afterwards.
int n, m;
// Per-vertex values read in main; not referenced by dfs0/dfs1.
vector<int> C;
// Adjacency lists of the tree (0-based vertex ids).
vector<vector<int>> v;
// anc[u][i] = {vertex reached after 2^i steps along u's chain, total distance
// of those steps (clamped to inf)}. The chain of u is
// u -> anc[u][0].f -> anc[anc[u][0].f][0].f -> ... and ends at the sentinel n,
// whose entries are all {n, inf}.
vector<vector<pair<int, int>>> anc;
// cnt[u] = number of vertices on u's chain, u included; cnt[n] stays 0.
vector<int> cnt;
// mxh[u] = number of vertices on the longest path from u down into the part of
// the tree currently oriented below u (1 for a leaf); mxh[n] = 0.
vector<int> mxh;
// ans[u] = value printed for vertex u.
vector<int> ans;
/*
 * next(nd, h): walk the chain that starts at nd (nd itself at distance 0,
 * then anc[nd][0].first, ...) and return the first chain vertex whose distance
 * from nd is >= h, paired with that distance. Returns {nd, 0} when h == 0.
 * If the walk reaches the sentinel n, the distance is >= inf.
 *
 * Uses the binary-lifting table anc: anc[u][i] = {vertex after 2^i chain
 * steps, total distance of those steps}, clamped at inf.
 *
 * Fix: the previous version scanned levels upward from i = 0 and recursed
 * after every jump. anc[u][i].second is nondecreasing in i, so level 0 always
 * won whenever any level fit. Each call therefore advanced one chain step per
 * recursion level: O(h) time and O(h) stack depth. On a path-shaped tree,
 * dfs1 calls this with h around n/2 for every vertex, which is quadratic.
 * Descending from the largest jump is the standard binary-lifting walk:
 * O(k) per call, no recursion, and the same result.
 */
pair<int, int> next(int nd, int h){
    if(h == 0) return {nd, 0};
    int cur = nd;    // furthest chain vertex reached so far
    int rem = h;     // distance budget still left before reaching h
    int dist = 0;    // distance from nd to cur
    for(int i = k - 1; i >= 0; i--){
        // Take the 2^i-step jump only if it does not overshoot distance h.
        if(anc[cur][i].second <= rem){
            rem -= anc[cur][i].second;
            dist += anc[cur][i].second;
            cur = anc[cur][i].first;
        }
    }
    // cur is the last chain vertex at distance <= h; if it sits exactly at h
    // it is the answer.
    if(rem == 0) return {cur, dist};
    // Otherwise the very next chain vertex is the first one at distance > h.
    return {anc[cur][0].first, dist + anc[cur][0].second};
}
void dfs0(int nd, int ss){
for(int x: v[nd]) if(x != ss) dfs0(x, nd);
int mx = 0, smx = 0, in = n;
for(int x: v[nd]) if(x != ss){
if(mxh[x] > mx) smx = mx, mx = mxh[x], in = x;
else if(mxh[x] > smx) smx = mxh[x];
}
mxh[nd] = mxh[in]+1;
auto nxt = next(in, smx);
nxt.sc++;
nxt.sc = min(nxt.sc, inf);
anc[nd][0] = nxt;
for(int i = 1; i < k; i++){
anc[nd][i].f = anc[anc[nd][i-1].f][i-1].f;
anc[nd][i].sc = anc[nd][i-1].sc + anc[anc[nd][i-1].f][i-1].sc;
anc[nd][i].sc = min(anc[nd][i].sc, inf);
}
cnt[nd] = cnt[nxt.f]+1;
}
// Rerooting pass, run after dfs0(0, -1). When dfs1(nd, ss) is entered, the
// mxh / anc / cnt entries of every neighbour of nd are oriented away from nd.
// Children still hold their dfs0 values, and the parent ss was rebuilt by the
// caller's loop as if nd were its parent. So nd can be treated as the root of
// the whole tree.
//
// Writes ans[nd] = cnt of the first kept vertex of the deepest neighbour's
// chain, i.e. the length of nd's chain excluding nd itself, then recurses into
// every child. Before returning, nd's own anc / cnt / mxh are restored to the
// values they had on entry.
//
// NOTE(review): the per-vertex values C read in main are never used, so ans is
// a count of chain vertices, not of distinct C values among them. If the task
// asks for distinct values, this is a likely cause of the wrong-answer
// verdicts -- TODO confirm against the problem statement.
void dfs1(int nd, int ss){
vector<pair<int, int>> sth;
// All neighbours (parent included) as {depth, id}, deepest first; ties go to
// the larger id. Padded with the sentinel {0, n} so sth[0..2] always exist.
for(int x: v[nd]) sth.pb({mxh[x], x});
sort(sth.rbegin(), sth.rend());
while(sth.size() < 3) sth.pb({0, n});
// nd as root: deepest neighbour's chain with the vertices fewer than
// second-largest-depth steps from it dropped.
auto nxt = next(sth[0].sc, sth[1].f);
ans[nd] = cnt[nxt.f];
// Save nd's state; it is re-oriented for each child below and restored after.
auto old_anc = anc[nd];
int old_cnt = cnt[nd], old_mxh = mxh[nd];
for(int x: v[nd]) if(x != ss){
// Rebuild nd as seen from child x (x excluded): among the top three
// neighbours, take the deepest and second-deepest ones that are not x.
// Note: this `nxt` shadows the one declared above.
pair<int, int> nxt;
if(x == sth[0].sc) nxt = next(sth[1].sc, sth[2].f), mxh[nd] = mxh[sth[1].sc]+1;
else if(x == sth[1].sc) nxt = next(sth[0].sc, sth[2].f), mxh[nd] = mxh[sth[0].sc]+1;
else nxt = next(sth[0].sc, sth[1].f), mxh[nd] = mxh[sth[0].sc]+1;
// One extra edge from nd to the chosen neighbour; clamp to inf.
nxt.sc++;
nxt.sc = min(nxt.sc, inf);
anc[nd][0] = nxt;
// Rebuild the binary-lifting levels on top of the new first step.
for(int i = 1; i < k; i++){
anc[nd][i].f = anc[anc[nd][i-1].f][i-1].f;
anc[nd][i].sc = anc[nd][i-1].sc + anc[anc[nd][i-1].f][i-1].sc;
anc[nd][i].sc = min(anc[nd][i].sc, inf);
}
cnt[nd] = cnt[nxt.f]+1;
dfs1(x, nd);
}
// Undo the per-child re-orientation so the caller sees nd's entry state.
anc[nd] = old_anc;
cnt[nd] = old_cnt;
mxh[nd] = old_mxh;
}
/*
 * Reads n and m, then n - 1 edges as 1-based endpoint pairs, then n values
 * into C. Allocates the per-vertex state with an extra slot for the sentinel
 * vertex n, runs both tree passes, and prints ans, one value per line.
 */
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> m;

    // Storage indexed by the n real vertices only.
    v.resize(n);
    C.resize(n);
    ans.resize(n);

    // DP state has one extra slot: index n is the sentinel that ends every
    // chain (all lifting entries {n, inf}, depth 0, count 0).
    anc.resize(n + 1, vector<pair<int, int>>(k, pair<int, int>(n, inf)));
    cnt.resize(n + 1, 0);
    mxh.resize(n + 1, 1);
    mxh[n] = 0;

    // n - 1 undirected edges, converted to 0-based ids.
    for(int e = 1; e < n; e++){
        int a, b;
        cin >> a >> b;
        --a;
        --b;
        v[a].push_back(b);
        v[b].push_back(a);
    }
    for(size_t i = 0; i < C.size(); i++) cin >> C[i];

    dfs0(0, -1);
    dfs1(0, -1);

    for(size_t i = 0; i < ans.size(); i++) cout << ans[i] << "\n";
}
/*
 * Judge verdicts, table 1 (pasted from the grader):
 *   # | Verdict   | Execution time | Memory | Grader output
 *   1 | Correct   | 0 ms           | 320 KB | Output is correct
 *   2 | Incorrect | 2 ms           | 720 KB | Output isn't correct
 *   3 | Halted    | 0 ms           | 0 KB   | -
 */
/*
 * Judge verdicts, table 2 (pasted from the grader):
 *   # | Verdict   | Execution time | Memory   | Grader output
 *   1 | Incorrect | 201 ms         | 31136 KB | Output isn't correct
 *   2 | Halted    | 0 ms           | 0 KB     | -
 */
/*
 * Judge verdicts, table 3 (pasted from the grader):
 *   # | Verdict             | Execution time | Memory   | Grader output
 *   1 | Correct             | 323 ms         | 43204 KB | Output is correct
 *   2 | Execution timed out | 2067 ms        | 58864 KB | Time limit exceeded
 *   3 | Halted              | 0 ms           | 0 KB     | -
 */
/*
 * Judge verdicts, table 4 (pasted from the grader):
 *   # | Verdict   | Execution time | Memory | Grader output
 *   1 | Correct   | 0 ms           | 320 KB | Output is correct
 *   2 | Incorrect | 2 ms           | 720 KB | Output isn't correct
 *   3 | Halted    | 0 ms           | 0 KB   | -
 */