#include <bits/stdc++.h>
#include <unordered_map>
#pragma GCC optimize("O3")
using namespace std;
// Integer type used throughout the file.
// NOTE(review): `long` is 64 bits only on LP64 platforms; INF below does not fit in a 32-bit long.
typedef long ll;
// Pair of ll values (used by cnum for a chain index and a position inside that chain).
typedef pair<ll, ll> pll;
// Capacity of the per-vertex global arrays.
#define MAX 201010
// Number of levels in the binary-lifting ancestor table sp.
#define MAXS 18
// Large sentinel value; not referenced anywhere in the visible code.
#define INF 1000000000000000001
// Output separators: a single space and a newline.
#define bb ' '
#define ln '\n'
// Point-update / range-maximum segment tree over positions 1..N.
// Nodes are stored in heap layout (root at index 1, children 2i and 2i+1);
// every node additionally remembers the inclusive position range it covers.
struct segtree {
    ll N;                   // number of positions requested by the caller
    ll s;                   // number of leaves: smallest power of two >= max(N, 1)
    vector<ll> tree, l, r;  // tree[i]: maximum in node i; [l[i], r[i]]: positions covered by node i

    // Stores a at position x (1-based) and refreshes the maxima on the path to the root.
    void update(ll x, ll a) {
        x += s - 1;
        tree[x] = a;
        for (x >>= 1; x; x >>= 1) tree[x] = max(tree[x << 1], tree[(x << 1) + 1]);
    }

    // Returns the maximum over positions [low, high], starting the descent at node loc.
    // Precondition (not checked): [low, high] is a non-empty range inside [l[loc], r[loc]].
    ll query(ll low, ll high, ll loc = 1) {
        if (l[loc] == low && r[loc] == high) return tree[loc];
        const ll left = loc << 1;
        const ll right = left + 1;
        if (r[left] >= high) return query(low, high, left);   // range lies entirely in the left child
        if (l[right] <= low) return query(low, high, right);  // range lies entirely in the right child
        return max(query(low, r[left], left), query(l[right], high, right));
    }

    // Fills l[] and r[] for the subtree rooted at node x.
    void init(ll x = 1) {
        if (x >= s) {
            // Leaf: covers exactly one 1-based position.
            l[x] = r[x] = x - s + 1;
            return;
        }
        init(x * 2);
        init(x * 2 + 1);
        l[x] = l[x << 1];
        r[x] = r[(x << 1) + 1];
    }

    // Builds an all-zero tree for n positions.
    segtree(ll n) {
        N = n;
        // Integer doubling instead of ceil(log2(N)): no floating-point rounding
        // involved, and N <= 0 yields a valid one-leaf tree instead of undefined behaviour.
        s = 1;
        while (s < N) s <<= 1;
        tree.resize(2 * s + 1);
        l.resize(2 * s + 1);
        r.resize(2 * s + 1);
        init();
    }
};
// adj[v]: neighbours of v (sorted in main).
// sav[v][i]: vertex recorded as farthest from v when leaving through adj[v][i]
//            (calc fills child directions, calcp fills the parent direction).
// savdis[v][i]: distance from v to sav[v][i].
// rev[v][i]: for a child adj[v][i], the deepest vertex below v found through a different child,
//            or v itself when there is no other child (filled by calc).
vector<ll> adj[MAX], sav[MAX], savdis[MAX], rev[MAX];
// Indices into adj[v] of the directions with the largest, second largest and third largest savdis
// (-1 when absent); filled in main.
ll dir[MAX], ddir[MAX], dddir[MAX];
// chain[c]: vertices of heavy-light chain c in top-down order (built by make_chain).
vector<vector<ll>> chain;
// Per-vertex sets of C values inserted by dfs and consumed by getans.
vector<set<ll>> subtree, revtree;
// One max segment tree per chain (built by make_tree).
vector<segtree> chainseg;
// C[v]: value read for vertex v in main; depth[v]: depth below vertex 1 (depth[0] is set to -1 in main);
// mxdepv[v]: deepest vertex in the subtree of v (filled by calc).
ll C[MAX], depth[MAX], mxdepv[MAX];
// Not referenced anywhere in the visible code.
ll mxdep[MAX];
// sp[v][k]: the 2^k-th ancestor of v, or 0 when it does not exist.
ll sp[MAX][MAXS];
// Index of the chain currently being extended by make_chain.
ll cnt;
// num[v]: size of the subtree of v (filled by init).
ll num[MAX];
// cnum[v]: (chain index, 0-based position inside that chain) of v.
pll cnum[MAX];
// ans[v]: answer printed for vertex v.
ll ans[MAX];
// arr[v]: value most recently assigned to v in the chain segment trees; dfs saves and restores it.
ll arr[MAX];
// Depth-first pass from x (parent p, depth d) that fills the binary-lifting table,
// the depth array and the subtree sizes, and sizes the per-neighbour vectors.
// Returns the number of vertices in the subtree of x.
ll init(ll x, ll p = 0, ll d = 0) {
    // Ancestor table: level 0 is the parent, level k doubles level k - 1.
    sp[x][0] = p;
    for (ll k = 1; k < MAXS; ++k) {
        sp[x][k] = sp[sp[x][k - 1]][k - 1];
    }
    depth[x] = d;

    // One slot per neighbour in each of the direction-indexed vectors.
    const auto degree = adj[x].size();
    sav[x].resize(degree);
    savdis[x].resize(degree);
    rev[x].resize(degree);

    ll total = 1;
    for (ll child : adj[x]) {
        if (child != p) total += init(child, x, d + 1);
    }
    num[x] = total;
    return total;
}
// Bottom-up pass over the subtree of x (parent p).
// For every child direction i it fills
//   sav[x][i]    - deepest vertex in that child's subtree,
//   savdis[x][i] - its distance from x,
//   rev[x][i]    - deepest vertex below x reachable through a *different* child (x itself if none),
// and sets mxdepv[x] to the deepest vertex in the whole subtree of x.
// The slot of the parent direction is left untouched.
void calc(ll x, ll p = 0) {
    // Signed copy of the degree so loop indices and bounds share one type.
    const ll deg = static_cast<ll>(adj[x].size());

    ll mx = 0;  // greatest depth found below x so far
    mxdepv[x] = x;
    for (ll i = 0; i < deg; i++) {
        const ll child = adj[x][i];
        if (child == p) continue;
        calc(child, x);
        if (mx < depth[mxdepv[child]]) mx = depth[mxdepv[child]], mxdepv[x] = mxdepv[child];
        sav[x][i] = mxdepv[child];
        savdis[x][i] = depth[sav[x][i]] - depth[x];
    }

    // Find the children that reach depth mx (the last two seen are kept) and the best
    // child that reaches strictly less. Named `deepest` so the global chain counter
    // `cnt` is no longer shadowed.
    ll deepest = 0;
    ll vv = -1, vvv = -1;  // last and second-to-last child reaching depth mx
    ll nv = -1;            // best child among those that do not reach depth mx
    ll nmx = 0;
    for (ll child : adj[x]) {
        if (child == p) continue;
        if (depth[mxdepv[child]] == mx) deepest++, vvv = vv, vv = child;
        else if (depth[mxdepv[child]] > nmx) nmx = depth[mxdepv[child]], nv = child;
    }

    // Alternative used for direction vv itself: another child of equal depth if one
    // exists, otherwise the runner-up.
    const ll alt = (deepest >= 2 ? vvv : nv);
    for (ll i = 0; i < deg; i++) {
        if (adj[x][i] == p) continue;
        const ll pick = (adj[x][i] == vv ? alt : vv);
        // No other child available: x itself is the farthest vertex outside this direction.
        rev[x][i] = (pick == -1 ? x : mxdepv[pick]);
    }
}
// Heavy-light decomposition: appends x to the chain currently indexed by the global
// `cnt`, continues that chain through the child with the largest subtree, and opens
// a new chain for each remaining child.
void make_chain(ll x, ll p = 0) {
    chain[cnt].push_back(x);
    cnum[x] = pll(cnt, static_cast<ll>(chain[cnt].size() - 1));

    // Heavy child: the first child with the strictly largest subtree size.
    ll heavy = 0;
    ll heavy_size = 0;
    for (ll child : adj[x]) {
        if (child != p && heavy_size < num[child]) {
            heavy_size = num[child];
            heavy = child;
        }
    }
    if (heavy) make_chain(heavy, x);

    for (ll child : adj[x]) {
        if (child == p || child == heavy) continue;
        ++cnt;
        chain.push_back(vector<ll>());
        make_chain(child, x);
    }
}
void make_tree() {
ll i;
for (i = 0; i < chain.size(); i++) chainseg.push_back(segtree(chain[i].size()));
}
// Writes x as the value of vertex v in the segment tree of v's chain.
void update(ll v, ll x) {
    const pll& where = cnum[v];
    // Chain positions are stored 0-based; the segment tree is 1-based.
    chainseg[where.first].update(where.second + 1, x);
}
// Lowest common ancestor of u and v in the tree rooted at vertex 1 (binary lifting).
ll lca(ll u, ll v) {
    if (depth[u] != depth[v]) {
        if (depth[u] < depth[v]) swap(u, v);
        // Raise the deeper vertex u to the depth of v.
        for (ll k = MAXS - 1; k >= 0; --k) {
            const ll up = sp[u][k];
            if (depth[up] >= depth[v]) u = up;
        }
    }
    if (u == v) return u;
    // Lift both vertices together while their ancestors still differ.
    for (ll k = MAXS - 1; k >= 0; --k) {
        if (sp[u][k] != sp[v][k]) {
            u = sp[u][k];
            v = sp[v][k];
        }
    }
    return sp[v][0];
}
// Heavy-light path query: maximum stored value over the vertices on the path u..v
// (never less than 0, the initial value of the running maximum).
ll mxval(ll u, ll v) {
    const ll top = lca(u, v);
    const ll top_chain = cnum[top].first;
    const ll top_pos = cnum[top].second + 1;
    ll best = 0;

    // Climbs from w chain by chain until w sits on the chain of `top`,
    // folding each chain prefix that was crossed into `best`.
    auto climb = [&](ll w) -> ll {
        while (cnum[w].first != top_chain) {
            const ll c = cnum[w].first;
            best = max(best, chainseg[c].query(1, cnum[w].second + 1));
            w = sp[chain[c][0]][0];  // parent of the chain head
        }
        return w;
    };

    u = climb(u);
    v = climb(v);
    best = max(best, chainseg[top_chain].query(top_pos, cnum[u].second + 1));
    best = max(best, chainseg[top_chain].query(top_pos, cnum[v].second + 1));
    return best;
}
// Distance (number of edges) between u and v when their lowest common ancestor l is already known.
ll dis(ll u, ll v, ll l) {
    return depth[u] + depth[v] - 2 * depth[l];
}
// Distance (number of edges) between u and v.
ll dis(ll u, ll v) {
    return dis(u, v, lca(u, v));
}
// With the tree re-rooted at r, returns the x-th ancestor of v, i.e. the vertex x steps
// from v along the path v -> r. Returns 0 when that path is shorter than x.
ll prtx(ll r, ll v, ll x) {
    if (x == 0) return v;

    // Moves `node` up `steps` levels in the tree rooted at vertex 1.
    auto lift = [](ll node, ll steps) -> ll {
        for (ll k = MAXS - 1; k >= 0; --k) {
            const ll jump = static_cast<ll>(1) << k;
            if (steps >= jump) {
                steps -= jump;
                node = sp[node][k];
            }
        }
        return node;
    };

    const ll l = lca(r, v);
    const ll path_len = dis(r, v, l);
    if (path_len < x) return 0;

    // Target lies on the r side of the LCA: reach it by climbing from r instead.
    if (dis(l, v, l) < x) return lift(r, path_len - x);
    // Target lies between v and the LCA: plain upward climb from v.
    return lift(v, x);
}
// Best direction index of v that differs from `ban`: dir[v] unless it is banned, else ddir[v].
ll getfar(ll v, ll ban) {
    return dir[v] != ban ? dir[v] : ddir[v];
}
// Best direction index of v that differs from both banned indices.
// A banned index of -1 means "nothing banned" for that argument.
ll getfar(ll v, ll ban1, ll ban2) {
    const ll lo = min(ban1, ban2);
    const ll hi = max(ban1, ban2);
    if (hi == -1) return dir[v];         // neither argument bans anything
    if (lo == -1) return getfar(v, hi);  // only one real ban

    auto allowed = [lo, hi](ll idx) { return idx != lo && idx != hi; };
    if (allowed(dir[v])) return dir[v];
    if (allowed(ddir[v])) return ddir[v];
    return dddir[v];
}
// Index of the first element of the sorted vector v that is not less than c
// (v.size() when every element is smaller).
ll getind(vector<ll>& v, ll c) {
    const auto pos = lower_bound(v.begin(), v.end(), c);
    return static_cast<ll>(pos - v.begin());
}
// Moves the root from r1 to its neighbour r2, where adj[r1][ind] == r2, and rewrites
// the stored values of both vertices:
//   r2 (new root) gets the sum of its two largest direction distances (one if only one exists);
//   r1 gets the same sum, but with the direction towards r2 excluded.
void prop(ll r1, ll r2, ll ind) {
    ll root_val = savdis[r2][dir[r2]];
    if (ddir[r2] != -1) root_val += savdis[r2][ddir[r2]];
    arr[r2] = root_val;
    update(r2, root_val);

    const ll f1 = getfar(r1, ind);      // best direction of r1 other than ind
    const ll f2 = getfar(r1, ind, f1);  // next best, also avoiding f1
    ll old_val = 0;
    if (f1 != -1) {
        old_val = savdis[r1][f1];
        if (f2 != -1) old_val += savdis[r1][f2];
    }
    arr[r1] = old_val;
    update(r1, old_val);
}
// Rerooting traversal. While x is the current root:
//  1. for every direction i of x, take the farthest vertex farv in that direction; unless the
//     maximum stored value on the path farv..adj[x][i] is at least savdis[x][i], compute the
//     vertex `root` that is (savdis[x][i] - 1) / 2 steps from farv towards x and record C[x]
//     in revtree[...] (when `root` is an ancestor of x in the tree rooted at 1) or in subtree[root];
//  2. for every child v, move the root to v with prop(), recurse, then restore the stored
//     values of x and v.
void dfs(ll x, ll p = 0) {
    // Signed copy of the degree so loop indices and bounds share one type.
    const ll deg = static_cast<ll>(adj[x].size());

    for (ll i = 0; i < deg; i++) {
        // The original also computed an unused `fardir` here (a binary search per edge); removed.
        const ll farv = sav[x][i];
        if (mxval(farv, adj[x][i]) >= savdis[x][i]) continue;
        const ll xx = (savdis[x][i] - 1) / 2;
        const ll root = prtx(x, farv, xx);
        if (lca(root, x) == root) revtree[prtx(x, root, 1)].insert(C[x]);
        else subtree[root].insert(C[x]);
    }

    for (ll i = 0; i < deg; i++) {
        const ll v = adj[x][i];
        if (v == p) continue;
        // Save the values prop() is about to overwrite.
        const ll saved_x = arr[x];
        const ll saved_v = arr[v];
        prop(x, v, i);
        dfs(v, x);
        // Restore the state in which x is the root.
        arr[x] = saved_x;
        arr[v] = saved_v;
        update(x, arr[x]);
        update(v, arr[v]);
    }
}
// mp[c]: current multiplicity of value c in the running multiset maintained by getans
// (seeded in main from every revtree set).
ll mp[MAX];
// Number of distinct values c with mp[c] != 0.
ll anscnt;
// Walks the tree from x (parent p) while maintaining the multiset mp / anscnt:
// on entry the values in subtree[x] are added and those in revtree[x] removed,
// ans[x] is set to the number of distinct values present, the children are visited,
// and on exit both changes are undone in reverse order.
void getans(ll x, ll p = 0) {
    auto add = [](ll c) {
        if (!mp[c]) anscnt++;
        mp[c]++;
    };
    auto drop = [](ll c) {
        mp[c]--;
        if (!mp[c]) anscnt--;
    };

    for (ll c : subtree[x]) add(c);
    for (ll c : revtree[x]) drop(c);

    ans[x] = anscnt;
    for (ll child : adj[x]) {
        if (child != p) getans(child, x);
    }

    for (ll c : revtree[x]) add(c);
    for (ll c : subtree[x]) drop(c);
}
// Top-down pass that fills the parent-direction slot of sav / savdis for every vertex
// other than vertex 1. The candidates are the farthest vertex below the parent outside
// x's own subtree (rev) and, unless the parent is vertex 1, the parent's own upward answer.
void calcp(ll x, ll p = 0) {
    if (x != 1) {
        const ll up = getind(adj[x], p);  // index of the parent inside adj[x]
        ll far = rev[p][getind(adj[p], x)];
        if (p != 1) {
            const ll above = sav[p][getind(adj[p], sp[p][0])];
            // Compare in increasing vertex-id order; on equal distance the smaller id wins.
            const ll lo = min(far, above);
            const ll hi = max(far, above);
            far = (dis(x, lo) >= dis(x, hi)) ? lo : hi;
        }
        sav[x][up] = far;
        savdis[x][up] = dis(x, far);
    }
    for (ll child : adj[x]) {
        if (child != p) calcp(child, x);
    }
}
signed main() {
ios::sync_with_stdio(false), cin.tie(0);
depth[0] = -1;
ll N, M;
cin >> N >> M;
ll i, j;
ll a, b;
for (i = 1; i < N; i++) cin >> a >> b, adj[a].push_back(b), adj[b].push_back(a);
for (i = 1; i <= N; i++) cin >> C[i];
for (i = 1; i <= N; i++) sort(adj[i].begin(), adj[i].end());
init(1);
calc(1);
calcp(1);
cnt = 0;
chain.push_back(vector<ll>());
make_chain(1);
make_tree();
//400ms
for (i = 1; i <= N; i++) {
ll mx = 0;
dir[i] = ddir[i] = dddir[i] = -1;
for (j = 0; j < adj[i].size(); j++) if (mx < savdis[i][j]) mx = savdis[i][j], dir[i] = j;
mx = 0;
for (j = 0; j < dir[i]; j++) if (mx < savdis[i][j]) mx = savdis[i][j], ddir[i] = j;
for (j = max((ll)0, dir[i] + 1); i < adj[i].size(); j++) if (mx < savdis[i][j]) mx = savdis[i][j], ddir[i] = j;
mx = 0;
for (j = 0; j < adj[i].size(); j++) {
if (j == dir[i] || j == ddir[i]) continue;
if (mx < savdis[i][j]) mx = savdis[i][j], dddir[i] = j;
}
if (i != 1) {
ll p = getind(adj[i], sp[i][0]);
ll r = getfar(i, p);
ll rr = getfar(i, p, r);
ll xx = 0;
if (r != -1) xx += savdis[i][r];
if (rr != -1) xx += savdis[i][rr];
arr[i] = xx;
update(i, xx);
}
else {
ll dd;
dd = ddir[1];
if (dd == -1) arr[1] = depth[sav[1][dir[1]]];
else arr[1] = depth[sav[1][dd]] + depth[sav[1][dir[1]]];
update(1, arr[1]);
}
}
subtree.resize(N + 1);
revtree.resize(N + 1);
dfs(1);
for (i = 1; i <= N; i++) for (auto c : revtree[i]) if (!(mp[c]++)) anscnt++;
getans(1);
for (i = 1; i <= N; i++) cout << ans[i] << ln;
}
Compilation message
joi2019_ho_t5.cpp: In function 'void calc(ll, ll)':
joi2019_ho_t5.cpp:79:16: warning: comparison of integer expressions of different signedness: 'll' {aka 'long int'} and 'std::vector<long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
79 | for (i = 0; i < adj[x].size(); i++) {
| ~~^~~~~~~~~~~~~~~
joi2019_ho_t5.cpp:97:17: warning: comparison of integer expressions of different signedness: 'll' {aka 'long int'} and 'std::vector<long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
97 | for (i = 0; i < adj[x].size(); i++) {
| ~~^~~~~~~~~~~~~~~
joi2019_ho_t5.cpp:104:17: warning: comparison of integer expressions of different signedness: 'll' {aka 'long int'} and 'std::vector<long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
104 | for (i = 0; i < adj[x].size(); i++) {
| ~~^~~~~~~~~~~~~~~
joi2019_ho_t5.cpp:110:16: warning: comparison of integer expressions of different signedness: 'll' {aka 'long int'} and 'std::vector<long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
110 | for (i = 0; i < adj[x].size(); i++) {
| ~~^~~~~~~~~~~~~~~
joi2019_ho_t5.cpp: In function 'void make_tree()':
joi2019_ho_t5.cpp:135:16: warning: comparison of integer expressions of different signedness: 'll' {aka 'long int'} and 'std::vector<std::vector<long int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
135 | for (i = 0; i < chain.size(); i++) chainseg.push_back(segtree(chain[i].size()));
| ~~^~~~~~~~~~~~~~
joi2019_ho_t5.cpp: In function 'void dfs(ll, ll)':
joi2019_ho_t5.cpp:210:16: warning: comparison of integer expressions of different signedness: 'll' {aka 'long int'} and 'std::vector<long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
210 | for (i = 0; i < adj[x].size(); i++) {
| ~~^~~~~~~~~~~~~~~
joi2019_ho_t5.cpp:211:6: warning: unused variable 'fardir' [-Wunused-variable]
211 | ll fardir = getfar(adj[x][i], getind(adj[adj[x][i]], x));
| ^~~~~~
joi2019_ho_t5.cpp:220:16: warning: comparison of integer expressions of different signedness: 'll' {aka 'long int'} and 'std::vector<long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
220 | for (i = 0; i < adj[x].size(); i++) {
| ~~^~~~~~~~~~~~~~~
joi2019_ho_t5.cpp: In function 'int main()':
joi2019_ho_t5.cpp:295:17: warning: comparison of integer expressions of different signedness: 'll' {aka 'long int'} and 'std::vector<long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
295 | for (j = 0; j < adj[i].size(); j++) if (mx < savdis[i][j]) mx = savdis[i][j], dir[i] = j;
| ~~^~~~~~~~~~~~~~~
joi2019_ho_t5.cpp:298:38: warning: comparison of integer expressions of different signedness: 'll' {aka 'long int'} and 'std::vector<long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
298 | for (j = max((ll)0, dir[i] + 1); i < adj[i].size(); j++) if (mx < savdis[i][j]) mx = savdis[i][j], ddir[i] = j;
| ~~^~~~~~~~~~~~~~~
joi2019_ho_t5.cpp:300:17: warning: comparison of integer expressions of different signedness: 'll' {aka 'long int'} and 'std::vector<long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
300 | for (j = 0; j < adj[i].size(); j++) {
| ~~^~~~~~~~~~~~~~~
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
11 ms |
19276 KB |
Output is correct |
2 |
Runtime error |
37 ms |
40588 KB |
Execution killed with signal 11 |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Runtime error |
452 ms |
146708 KB |
Execution killed with signal 11 |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Runtime error |
606 ms |
186604 KB |
Execution killed with signal 11 |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
11 ms |
19276 KB |
Output is correct |
2 |
Runtime error |
37 ms |
40588 KB |
Execution killed with signal 11 |
3 |
Halted |
0 ms |
0 KB |
- |