#include<bits/stdc++.h>
// Shorthand used throughout the file when appending to a vector.
#define pb emplace_back
using namespace std;
using ll = long long;
// Capacity of every fixed-size per-vertex / per-value array in this file.
const int maxn = 300010;
// n: number of vertices (main reads n-1 edges and n values).
// m: second number on the first input line; read by main but not used by any
//    code visible in this file.
int n, m;
// res[v]: answer printed for vertex v (main prints res[1..n]).
// c[v]:   value read from input for vertex v; put_in() counts distinct values.
// dep[v]: depth of v written by get_dep_len() for the current root.
int res[maxn], c[maxn], dep[maxn];
// Undirected adjacency lists; main adds each input edge in both directions.
// NOTE(review): vertex ids are presumably 1-based and the graph a tree - confirm
// against the problem statement.
vector<int> edge[maxn];
namespace {
int _res, cnt[maxn];
void put_in(int a, int d) {
//cerr << (d>0 ? "put in : " : "get out : " ) << ' ' << a << '\n';
if (!cnt[c[a]]) ++_res, assert( d > 0);
if ((cnt[c[a]]+=d) == 0) --_res;
//cerr << "res : " << _res << '\n';
}
int get_res() {
return _res;
}
pair<int,int> ans;
void _dfs(int now, int last = -1) {
static int d;
++d;
for (int u : edge[now]) if (u != last)
_dfs(u, now);
ans = max(ans, {d, now});
--d;
}
int getfar(int now) {
ans = {-1, -1};
_dfs(now);
return ans.second;
}
};
int len[maxn];
void get_dep_len(int now, int last = -1) {
static int d;
len[now] = 0;
dep[now] = ++d;
for (int u : edge[now]) if (u != last) {
get_dep_len(u, now);
len[now] = max(len[u]+1, len[now]);
}
--d;
}
void get_all(int now, int last = -1) {
static vector<int> st;
auto ost = [&](){
return;
cerr << "now stack " << now << '\n';
cerr << "all stack : ";
for (int u : st)cerr << u << ' ';
cerr << '\n';
};
//cerr << "now " << now << '\n';
int p = -1, up = -1, sec = -1;
for (int u : edge[now]) if (u != last)
if (len[u]+1 > up) sec = up, up = len[u]+1, p = u;
else sec = max(sec, len[u]+1);
if (p == -1) {
ost();
res[now] = max(res[now], get_res());
return;
}
while (st.size() && dep[now] - dep[st.back()] <= sec) {
put_in(st.back(), -1);
//cerr << "rad is " << dep[now]-dep[st.back()] <<' ' << "lim is " << sec;
st.pop_back();
}
st.pb(now);
put_in(now, 1);
//cerr << "nxt is big!\n";
get_all(p, now);
while (st.size() && dep[now] - dep[st.back()] <= up) {
put_in(st.back(), -1);
st.pop_back();
}
st.pb(now);
put_in(now, 1);
for (int u : edge[now]) if (u != last && u != p)
get_all(u, now);
if (st.size() && st.back() == now) {
st.pop_back();
put_in(now, -1);
}
ost();
//cerr << " query " << get_res() << '\n';
res[now] = max(res[now], get_res());
}
// Debug helper: prints len[1..n] and dep[1..n] to stderr, each under its own
// heading, values separated by single spaces.
void debug() {
    auto dump = [](const char *heading, const int *values) {
        cerr << heading;
        for (int v = 1; v <= n; ++v) {
            // The last value is followed by a newline instead of a space.
            cerr << values[v] << " \n"[v == n];
        }
    };
    dump("len\n", len);
    dump("dep\n", dep);
}
// Reads the graph and the per-vertex values, runs the two-pass computation
// from both ends of a longest path, and prints res[1..n] one per line.
signed main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    cin >> n >> m;
    for (int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        edge[u].emplace_back(v);
        edge[v].emplace_back(u);
    }
    for (int v = 1; v <= n; ++v) {
        cin >> c[v];
    }

    // a is farthest from vertex 1, b is farthest from a.
    const int a = getfar(1);
    const int b = getfar(a);

    // Same pipeline for each endpoint; get_all() keeps the better result.
    for (int root : {a, b}) {
        get_dep_len(root);
        get_all(root);
    }

    for (int v = 1; v <= n; ++v) {
        cout << res[v] << '\n';
    }
}
Compilation message
joi2019_ho_t5.cpp: In function 'void get_all(int, int)':
joi2019_ho_t5.cpp:57:29: warning: suggest explicit braces to avoid ambiguous 'else' [-Wdangling-else]
for (int u : edge[now]) if (u != last)
                        ^
# | 결과 | 실행 시간 | 메모리 | Grader output
1 | Correct | 8 ms | 7424 KB | Output is correct
2 | Incorrect | 10 ms | 7552 KB | Output isn't correct
3 | Halted | 0 ms | 0 KB | -
# | 결과 | 실행 시간 | 메모리 | Grader output
1 | Incorrect | 177 ms | 13136 KB | Output isn't correct
2 | Halted | 0 ms | 0 KB | -
# | 결과 | 실행 시간 | 메모리 | Grader output
1 | Incorrect | 270 ms | 16068 KB | Output isn't correct
2 | Halted | 0 ms | 0 KB | -
# | 결과 | 실행 시간 | 메모리 | Grader output
1 | Correct | 8 ms | 7424 KB | Output is correct
2 | Incorrect | 10 ms | 7552 KB | Output isn't correct
3 | Halted | 0 ms | 0 KB | -