#include <bits/stdc++.h>
using namespace std;
// (first, second) answer pair; cmp() prefers a larger first, then a smaller second.
#define pp pair <int, int>
#define fi first
#define se second
// maxn: capacity of the node/color/depth-indexed arrays.
// blocksz: colors with at least this many nodes go to solve_large; smaller
// non-empty colors are scored inside dfs_merge (via mark[] / l[]).
const int maxn = 3e5 + 9, blocksz = 320;
// n nodes (1-based), k colors (main classifies colors 0..k-1); c[v] is the
// color of node v; h[v] is the depth of v (root 1 stays at depth 0, children
// get parent depth + 1); sz[col] is the number of nodes with color col.
int n, k, c[maxn], h[maxn], sz[maxn];
// Euler tour filled by dfs: in[v] is v's preorder index and out[v] the largest
// preorder index inside v's subtree, so w is in v's subtree iff
// in[v] <= in[w] <= out[v].
int in[maxn], out[maxn], timedfs = 0;
// adj[v]: children of v (the tree is stored rooted at node 1).
vector<int> adj[maxn];
// pos[v][d]: number of nodes at relative depth d in v's subtree (built by dfs_merge).
deque<int> pos[maxn];
// Colors split by size class; mark[col] != 0 exactly for the small ones.
vector<int> large, small;
int mark[maxn];
// Best answer so far. Starts at {1, 0}, presumably because a lone node is
// always a valid answer of size 1 needing no swaps. TODO confirm against the statement.
pp res = {1, 0};
// Returns the preferred of two (value, cost) answers: the larger .first wins;
// when .first ties, the smaller .second wins, and A is kept if both tie.
std::pair<int, int> cmp(std::pair<int, int> A, std::pair<int, int> B){
    if(A.first != B.first)
        return A.first > B.first ? A : B;
    return B.second < A.second ? B : A;
}
// Euler tour from u. Sets in[x] to x's preorder index (continuing the global
// counter timedfs), out[x] to the largest preorder index inside x's subtree,
// and h[x] = h[parent] + 1 for every proper descendant x of u.
// Iterative with an explicit stack: a recursive version nests one call per
// tree level, which can overflow the call stack on a path-shaped tree with up
// to maxn nodes. Children are entered in adj[] order, exactly like recursion.
void dfs(int u){
    std::vector<std::pair<int, std::size_t>> stk;   // (node, index of next child to enter)
    in[u] = ++timedfs;
    stk.emplace_back(u, 0);
    while(!stk.empty()){
        const int node = stk.back().first;
        const std::size_t next = stk.back().second;   // copied: emplace_back below may reallocate
        if(next < adj[node].size()){
            stk.back().second = next + 1;
            const int child = adj[node][next];
            h[child] = h[node] + 1;
            in[child] = ++timedfs;
            stk.emplace_back(child, 0);
        } else {
            // Every descendant has been numbered, so timedfs is the subtree's last index.
            out[node] = timedfs;
            stk.pop_back();
        }
    }
}
// Scratch state for the large-color pass (solve_large / dfs_color / dfs3 / dfs4).
// cur: color being processed; mx: deepest depth reached by the current dfs4 walk.
int cur, mx;
// cnt_color[d]: nodes of color cur at depth d; cnt[v]: nodes of color cur in
// v's subtree; mp[d]: nodes at depth d visited so far by the current dfs4 walk.
int cnt_color[maxn], cnt[maxn], mp[maxn];
// Working (first, second) answer for the node currently being scored
// (written by dfs3/dfs4 and by dfs_merge, then folded into res).
pp ans;
// Counts color cur over u's subtree: cnt_color[d] gains the number of color-cur
// nodes at depth d, and cnt[x] gains the number of color-cur nodes in x's
// subtree for every node x under u (solve_large clears both arrays first).
// Iterative: a preorder is built with an explicit stack, then walked backwards
// so each child's total is final before it is added to its parent. A recursive
// version nests one call per tree level and can overflow the call stack on
// path-shaped trees.
void dfs_color(int u){
    std::vector<int> order;        // u's subtree, every parent before its descendants
    std::vector<int> stk(1, u);
    while(!stk.empty()){
        const int x = stk.back();
        stk.pop_back();
        order.push_back(x);
        if(c[x] == cur){
            cnt_color[h[x]]++;
            cnt[x]++;
        }
        // Push in reverse so children are popped in adj[] order.
        for(auto it = adj[x].rbegin(); it != adj[x].rend(); ++it) stk.push_back(*it);
    }
    for(auto it = order.rbegin(); it != order.rend(); ++it)
        for(int v: adj[*it]) cnt[*it] += cnt[v];
}
void dfs4(int u){
if(mp[h[u]] < cnt_color[h[u]]) ans.fi++;
mp[h[u]]++;
mx = max(mx, h[u]);
for(int v: adj[u]) dfs4(v);
}
void dfs3(int u){
if(c[u] == cur){
ans.fi = mx = 0;
dfs4(u);
ans.se = ans.fi - cnt[u];
res = cmp(res, ans);
for(int i = h[u]; i <= mx; i++) mp[i] = 0;
}
else for(int v: adj[u]) dfs3(v);
}
// Scores every topmost node of `color` with whole-tree passes; main uses this
// only for colors with at least blocksz nodes.
void solve_large(int color){
    cur = color;
    // cnt is indexed by node (1..n) but cnt_color by depth, and depths start at
    // 0 (the root), so index 0 must be cleared as well.
    for(int i = 0; i <= n; i++) cnt[i] = cnt_color[i] = 0;
    dfs_color(1);
    dfs3(1);
}
// l[col]: nodes of each small color col (mark[col] != 0), sorted by depth in main;
// stays empty for large colors.
vector<int> l[maxn];
// Finishes node u once all of its children are finished.
// 1. Merges the children's depth profiles into pos[u], small-to-large: the
//    longer deque is kept, the shorter is added into it and cleared.
// 2. Prepends u itself at relative depth 0.
// 3. Scores u against its own color and folds the result into res.
static void dfs_merge_finish(int u){
    for(int v: adj[u]){
        if(pos[v].size() > pos[u].size()) std::swap(pos[u], pos[v]);
        for(std::size_t z = 0; z < pos[v].size(); z++) pos[u][z] += pos[v][z];
        pos[v].clear();
    }
    pos[u].push_front(1);

    ans.fi = ans.se = 0;
    // Nodes sharing u's color, sorted by depth; empty when that color is large.
    const std::vector<int> &same = l[c[u]];
    // u's subtree spans depths h[u] .. h[u] + reach - 1.
    const int reach = (int)pos[u].size();
    std::size_t i = 0;
    while(i < same.size()){
        const int depth = h[same[i]];
        std::size_t j = i;
        while(j < same.size() && h[same[j]] == depth) j++;   // [i, j): this depth's group
        const int rel = depth - h[u];
        if(rel >= reach) break;   // later groups are deeper still and cannot fit either
        if(rel >= 0){
            int in_subtree = 0;
            for(std::size_t t = i; t < j; t++)
                if(in[u] <= in[same[t]] && in[same[t]] <= out[u]) in_subtree++;
            // The group can fill at most as many places as u's subtree has at this depth.
            const int take = std::min((int)(j - i), pos[u][rel]);
            ans.fi += take;
            ans.se += take - in_subtree;
        }
        i = j;
    }
    res = cmp(res, ans);
}

// Small-color pass over u's subtree. Every node is finished (dfs_merge_finish)
// after all of its children, so pos[x][d] ends up as the number of nodes at
// relative depth d below x.
// Iterative: a recursive post-order nests one call per tree level and can
// overflow the call stack on path-shaped trees. Nodes are finished in reverse
// preorder, which puts each child before its parent.
void dfs_merge(int u){
    std::vector<int> order;   // u's subtree, every parent before its descendants
    std::vector<int> stk(1, u);
    while(!stk.empty()){
        const int x = stk.back();
        stk.pop_back();
        order.push_back(x);
        for(auto it = adj[x].rbegin(); it != adj[x].rend(); ++it) stk.push_back(*it);
    }
    for(auto it = order.rbegin(); it != order.rend(); ++it) dfs_merge_finish(*it);
}
// Strict weak ordering of nodes by depth, shallower first.
bool cmp_depth(int u, int v){
    const int du = h[u];
    const int dv = h[v];
    return du < dv;
}
// Reads n, k, each node's color and the 0-based parents of nodes 2..n.
// Colors with at least blocksz nodes are solved by solve_large; the remaining
// non-empty colors are collected per color, sorted by depth and solved by
// dfs_merge. Prints the best answer pair.
signed main(){
    ios_base::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);

    cin >> n >> k;
    for(int v = 1; v <= n; v++){
        cin >> c[v];
        sz[c[v]]++;
    }

    // Bucket the non-empty colors by how many nodes carry them.
    for(int col = 0; col < k; col++){
        if(sz[col] == 0) continue;
        if(sz[col] >= blocksz){
            large.push_back(col);
        } else {
            small.push_back(col);
            mark[col]++;
        }
    }

    // Parents arrive 0-based; nodes are stored 1-based.
    for(int v = 2; v <= n; v++){
        int parent;
        cin >> parent;
        adj[parent + 1].push_back(v);
    }
    dfs(1);

    for(int col: large) solve_large(col);

    for(int v = 1; v <= n; v++)
        if(mark[c[v]] != 0) l[c[v]].push_back(v);
    for(int col = 0; col < k; col++) sort(l[col].begin(), l[col].end(), cmp_depth);

    dfs_merge(1);
    cout << res.fi << " " << res.se;
}
/*
Judge results table captured together with the source (the results had not
loaded yet). Kept as a comment so the file compiles.

# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/