#include <bits/stdc++.h>
using namespace std;
// 64-bit integer shorthand.
using ll = long long;
// Array capacity (maximum vertex count plus slack) and a modulus constant.
constexpr int N = 200000 + 12;
constexpr int MOD = 1000000000 + 7;
// n = number of vertices, m = range of the values in a[] (the segment tree
// covers [1, m]), a[v] = value of vertex v (1-based).
int n, m;
int a[N];
// Visited marks for the BFS in f(); `timer` is the value meaning "visited".
bool vis[N];
bool timer;
// Adjacency lists of the tree (vertices 1..n).
vector<int> g[N];
// dist[k] = BFS distances from the k-th source passed to f(), in call order.
vector<vector<int>> dist;
// Tie-break flag read by go(): when set, vertices equidistant from both
// endpoints are answered as well.
bool zap;
// Breadth-first search from vertex d over the tree g (vertices 1..n).
// Appends to `dist` the vector of BFS distances from d (index 0 unused) and
// returns the last vertex taken off the queue, i.e. a vertex farthest from d
// (BFS dequeues vertices in non-decreasing distance order).
int f(int d) {
    vector<int> D(n + 1, 0);
    // Reset the visited marks; `timer` is the value that means "visited".
    timer = 1;
    fill(vis + 1, vis + n + 1, false);
    vis[d] = timer;
    int lst = d;
    queue<int> q;
    q.push(d);
    while (!q.empty()) {
        const int v = q.front();
        q.pop();
        lst = v;
        for (int to : g[v]) {
            if (vis[to] == timer) continue;
            vis[to] = timer;
            D[to] = D[v] + 1;
            q.push(to);
        }
    }
    // Move instead of copy: D is an O(n) buffer that is dead after this point.
    dist.push_back(std::move(D));
    return lst;
}
// Per-vertex data for the current root: dep = depth (root has depth 1),
// mxd = largest depth inside the vertex's subtree, res = accumulated answer.
int dep[N], mxd[N], res[N];
// d, d1 = vertices returned by the repeated BFS in test() (diameter endpoints);
// o, cr = indices into dist compared by go();
// gl[x] = number of distinct a[] values on the stack chain from entry x down to entry 0.
int d, d1, o, cr;
int gl[N];
// Computes dep[] for the children of v and mxd[] (deepest depth in the subtree)
// for v and its descendants. Then it reorders g[v] so that the parent (if any)
// is the last entry and a child whose subtree reaches depth mxd[v] is the first.
// pr == -1 marks v as the root.
void dfs(int v, int pr = -1) {
    mxd[v] = dep[v];
    for (int to : g[v]) {
        if (to == pr) continue;
        dep[to] = dep[v] + 1;
        dfs(to, v);
        mxd[v] = max(mxd[v], mxd[to]);
    }
    vector<int> &adj = g[v];
    // Move the parent to the back so that children occupy a prefix of adj.
    for (int i = 0; i < (int)adj.size(); i++) {
        if (adj[i] == pr) {
            swap(adj.back(), adj[i]);
            break;
        }
    }
    // Children are every slot for the root, all but the last slot otherwise.
    // The old bound size()-1 skipped the root's last child.
    const int kids = (int)adj.size() - (pr != -1);
    for (int i = 0; i < kids; i++) {
        if (mxd[adj[i]] == mxd[v]) {
            swap(adj[0], adj[i]);
        }
    }
}
// Persistent candidate stack with binary lifting. Entry 0 is the empty stack,
// `it` is the last allocated entry and `cur` the current top entry.
// ver[x] = vertex stored in entry x; up[x][k] = entry 2^k levels below x
// (0 once the bottom is passed); b = number of lifting levels in use.
int b = 17;
int it = 0;
int cur = 0;
int ver[N];
int up[N][18];
// Node of the persistent segment tree over positions [1, m].
// Leaves carry `sum` (true = position present). Internal nodes only hold their
// child pointers; their `sum` is never aggregated and stays false.
struct node {
    node *l = nullptr;
    node *r = nullptr;
    bool sum = false;

    node() {}
    node(bool v) : sum(v) {}
    node(node *L, node *R) : l(L), r(R) {}
};
// Handle to a persistent-tree node (nodes are allocated with new and never freed).
using pnode = node*;
// tr[x] = root of the persistent tree attached to stack entry x.
pnode tr[N] = {};
// Builds a full persistent segment tree over [tl, tr] with every leaf false
// and returns its root.
pnode build(int tl = 1, int tr = m) {
    if (tl != tr) {
        const int mid = (tl + tr) / 2;
        pnode left = build(tl, mid);
        pnode right = build(mid + 1, tr);
        return new node(left, right);
    }
    return new node();
}
// Returns a new version of tree v (covering [tl, tr]) in which leaf `pos` is
// true. Only the root-to-leaf path is copied; v itself is not modified.
pnode upd(int pos, pnode v, int tl = 1, int tr = m) {
    if (tl == tr) return new node(1);
    const int mid = (tl + tr) / 2;
    pnode left = v->l;
    pnode right = v->r;
    if (pos <= mid) {
        left = upd(pos, left, tl, mid);
    } else {
        right = upd(pos, right, mid + 1, tr);
    }
    return new node(left, right);
}
// Returns the value of leaf `pos` in tree v, which must cover [tl, tr].
bool get(int pos, pnode v, int tl = 1, int tr = m) {
    while (tl != tr) {
        const int mid = (tl + tr) / 2;
        if (pos > mid) {
            v = v->r;
            tl = mid + 1;
        } else {
            v = v->l;
            tr = mid;
        }
    }
    return v->sum;
}
// Root of the tree currently being processed by solve().
int rt = 0;
// e[x] = entries that bld() pushed directly on top of stack entry x.
vector<int> e[N];
void bld(int v) {
int sz = (int)g[v].size() - (rt != v), mx1 = dep[v];
if(!sz) return;
for(int i = 1; i < sz; i++) {
mx1 = max(mx1, mxd[g[v][i]]);
}
int bf = cur;
for(int i = 0; i < sz; i++) {
int to = g[v][i];
int mx = dep[v];
if(i) mx = mxd[g[v][0]];
if(i < sz - 1) mx = max(mx, mx1);
mx -= dep[v];
if(i == 1 && ver[cur] == v) {
cur = up[cur][0];
}
if(i <= 1 && dep[v] - dep[ver[cur]] <= mx) {
for(int i = b - 1; i >= 0; i--) {
int nv = up[cur][i];
if(dep[v] - dep[ver[nv]] <= mx) {
cur = nv;
}
}
cur = up[cur][0];
}
if(i <= 1 && !get(a[v], tr[cur])) {
it++;
e[cur].push_back(it);
ver[it] = v;
gl[it] = gl[cur] + 1;
up[it][0] = cur;
tr[it] = upd(a[v], tr[cur]);
for(int i = 1; i < b; i++) {
up[it][i] = up[up[it][i - 1]][i - 1];
}
cur = it;
}
bld(to);
}
cur = bf;
}
void go(int v) {
if(dist[cr][v] > dist[o][v] || (dist[cr][v] == dist[o][v] && zap)) {
int f = cur, val = -mxd[v] + dep[v] * 2;
set<int> r;
if(dep[ver[f]] >= val) {
for(int i = b - 1; i >= 0; i--) {
int nv = ver[up[f][i]];
if(dep[nv] >= val) {
f = up[f][i];
}
}
f = up[f][0];
}
res[v] += gl[f];
}
int sz = (int)g[v].size() - (rt != v), mx1 = dep[v];
if(!sz) return;
for(int i = 1; i < sz; i++) {
mx1 = max(mx1, mxd[g[v][i]]);
}
int bf = cur;
for(int i = 0; i < sz; i++) {
int to = g[v][i];
int mx = dep[v];
if(i) mx = mxd[g[v][0]];
if(i < sz - 1) mx = max(mx, mx1);
mx -= dep[v];
if(i == 1) {
cur = up[cur][0];
}
if(i <= 1 && dep[v] - dep[ver[cur]] <= mx) {
for(int i = b - 1; i >= 0; i--) {
int nv = up[cur][i];
if(dep[v] - dep[ver[nv]] <= mx) {
cur = nv;
}
}
cur = up[cur][0];
}
if(i <= 1) {
it++;
ver[it] = v;
auto check = [&](){
int f = cur;
while(f) {
if(a[ver[f]] == a[v]) return 0;
f = up[f][0];
}
return 1;
};
gl[it] = gl[cur] + (check());
up[it][0] = cur;
// tr[it] = upd(a[v], tr[cur]);
for(int i = 1; i < b; i++) {
up[it][i] = up[up[it][i - 1]][i - 1];
}
cur = it;
}
go(to);
}
cur = bf;
}
// Roots the tree at `root` (dep[root] = 1). It computes dep/mxd and the child
// order with dfs(), then lets go() add the answers of the vertices this root
// is responsible for.
void solve(int root) {
    // Entry 0 is the empty stack. Vertex 0 is its sentinel, with a depth so
    // small that it is never popped.
    ver[0] = 0;
    dep[0] = -(int)1e9;
    rt = root;
    dep[root] = 1;
    dfs(root);
    // bld(root) is intentionally not run. Nothing it writes (e[], tr[], and
    // the ver/gl/up entries it allocates) influences res[], so running it only
    // costs time and persistent-tree memory. With tr[0] left null, it also
    // crashed with SIGSEGV on its first get(). Since e[] is never filled, it
    // needs no clearing.
    it = cur = 0;
    go(root);
}
// Reads one test case: n, m, the n-1 tree edges and a[v] for every vertex.
// Finds two far-apart endpoints with repeated BFS, runs solve() from each,
// and prints res[1..n], one value per line.
void test() {
    cin >> n >> m;
    for (int edge = 1; edge < n; ++edge) {
        int u, w;
        cin >> u >> w;
        g[u].push_back(w);
        g[w].push_back(u);
    }
    for (int v = 1; v <= n; ++v) cin >> a[v];

    // Call order fixes the dist indices: dist[0] from vertex 1, dist[1] from d,
    // dist[2] from d1.
    d = f(1);
    d1 = f(d);
    f(d1);

    // Rooted at d: go() answers v when dist[1][v] > dist[2][v].
    cr = 1;
    o = 2;
    solve(d);

    // Rooted at d1: go() answers v when dist[2][v] > dist[1][v], and also ties (zap).
    zap = 1;
    cr = 2;
    o = 1;
    solve(d1);

    for (int v = 1; v <= n; ++v) cout << res[v] << '\n';
}
// Program entry: fast unsynchronised stream I/O, then a single test case.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    // Exactly one test case; reading a test count (`cin >> t`) is disabled.
    for (int t = 1; t > 0; --t) {
        test();
    }
}
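/*
Judge results pasted from the grader for this submission (four result tables).
Signal 11 is SIGSEGV, i.e. an invalid memory access.

Table 1
| # | Verdict       | Execution time | Memory   | Grader output                   |
|---|---------------|----------------|----------|---------------------------------|
| 1 | Runtime error | 26 ms          | 34384 KB | Execution killed with signal 11 |
| 2 | Halted        | 0 ms           | 0 KB     | -                               |

Table 2
| # | Verdict       | Execution time | Memory   | Grader output                   |
|---|---------------|----------------|----------|---------------------------------|
| 1 | Runtime error | 78 ms          | 46496 KB | Execution killed with signal 11 |
| 2 | Halted        | 0 ms           | 0 KB     | -                               |

Table 3
| # | Verdict       | Execution time | Memory   | Grader output                   |
|---|---------------|----------------|----------|---------------------------------|
| 1 | Runtime error | 93 ms          | 51152 KB | Execution killed with signal 11 |
| 2 | Halted        | 0 ms           | 0 KB     | -                               |

Table 4
| # | Verdict       | Execution time | Memory   | Grader output                   |
|---|---------------|----------------|----------|---------------------------------|
| 1 | Runtime error | 26 ms          | 34384 KB | Execution killed with signal 11 |
| 2 | Halted        | 0 ms           | 0 KB     | -                               |
*/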