#include <bits/stdc++.h>
#define pb push_back
#define SS ios_base::sync_with_stdio(0);cin.tie(nullptr);cout.tie(nullptr);
// #define int long long
#define all(v) v.begin(),v.end()
using namespace std;
// Capacity of every per-vertex / per-colour / per-depth array below.
// NOTE(review): assumes n and k stay below N (presumably n, k <= 1e5) — confirm limits.
constexpr int N = 100000 + 17;
// MX, inf and mod are not referenced by the visible solution code.
constexpr int MX = 400000 + 177;
constexpr int inf = 1000000000 + 1;
constexpr int mod = 998244353;

int n, k;   // vertex count and colour count, read in solve()
int K;      // heavy/light threshold on colour frequency, set to sqrt(n) in solve()

vector<int> g[N];   // g[v]: children of vertex v (1-based ids)
vector<int> e[N];   // e[c]: vertices whose colour is c (colours shifted to 1-based)
int a[N];           // a[v]: colour of vertex v

// Per-depth scratch counters: cnt[d] = vertices of the current subtree at depth d,
// in[d] = how many of those already have the candidate colour. Reset after each use.
int cnt[N];
int in[N];
// Euler tour: [tin[v], tout[v]] is v's subtree interval, ver[t] the vertex entered at time t.
int tin[N];
int tout[N];
int timer;
int ver[N];
// sz[v]: subtree size; boss[v]: heavy child (0 for a leaf); dep[v]: depth (root = 0).
int sz[N];
int boss[N];
int dep[N];
// mark[c] == 1 for light colours (evaluated in dfs2 instead of dfs1).
int mark[N];
// col[c][d]: number of colour-c vertices at depth d in the whole tree.
map<int, int> col[N];
// Best team size found so far, and the fewest swaps achieving that size.
int ans, sw;
// Preorder walk from v that fills in the static per-vertex data:
//   tin[v] / tout[v] - Euler-tour entry time and last time inside v's subtree,
//   ver[t]           - the vertex entered at time t,
//   col[c][d]        - count of colour-c vertices at depth d,
//   sz[v]            - subtree size,
//   boss[v]          - child with the largest subtree (earliest child on ties),
//   dep[child]       - dep[v] + 1 (dep[v] itself must already be set by the caller).
void calc(int v){
    ver[++timer] = v;
    tin[v] = timer;
    ++col[a[v]][dep[v]];
    sz[v] = 1;

    const vector<int>& kids = g[v];
    for(size_t j = 0; j < kids.size(); ++j){
        const int child = kids[j];
        dep[child] = dep[v] + 1;
        calc(child);
        sz[v] += sz[child];
    }

    // Strict '<' keeps the first child among equally large subtrees.
    int heaviest = boss[v];
    for(int child : kids){
        if(heaviest == 0 || sz[heaviest] < sz[child]) heaviest = child;
    }
    boss[v] = heaviest;

    tout[v] = timer;
}
// Adds v's whole subtree to the per-depth scratch counters used by dfs1:
// cnt[d] grows by the number of subtree vertices at depth d, and in[d] by
// the number of those that already have colour c.
void f(int v, int c){
    const int d = dep[v];
    cnt[d] += 1;
    in[d] += (a[v] == c) ? 1 : 0;
    for(size_t j = 0; j < g[v].size(); ++j){
        f(g[v][j], c);
    }
}
void dfs1(int v, int c){
if(a[v] == c){
f(v, c);
int cur = 0, cursw = 0, i = dep[v];
while(cnt[i]){
cur += in[i];
cur += min(col[c][i], cnt[i]) - in[i];
cursw += min(col[c][i], cnt[i]) - in[i];
cnt[i] = in[i] = 0;
i++;
}
// if(a[v] == 3) cout << cur;
if(ans < cur || (ans == cur && cursw < sw)){
ans = cur;
sw = cursw;
}
return;
}
for(int to : g[v]){
dfs1(to, c);
}
}
// True iff u is an ancestor of v (or u == v): v's Euler-tour interval
// [tin[v], tout[v]] must lie inside u's.
bool ch(int u, int v){
    if(tin[v] < tin[u]) return false;
    return tout[v] <= tout[u];
}
void dfs2(int v, int tp){
for(int to : g[v]){
if(boss[v] == to) continue;
dfs2(to, 0);
}
if(boss[v]) dfs2(boss[v], 1);
cnt[dep[v]]++;
for(int to : g[v]){
if(boss[v] == to) continue;
for(int i = tin[to]; i <= tout[to]; i++){
cnt[dep[ver[i]]]++;
}
}
if(mark[a[v]]){
unordered_set<int> st;
for(int x : e[a[v]]){
if(ch(v, x)) in[dep[x]]++;
st.insert(dep[x]);
}
int cur = 0, cursw = 0;
for(int d : st){
cur += in[d];
cur += min(col[a[v]][d], cnt[d]) - in[d];
cursw += min(col[a[v]][d], cnt[d]) - in[d];
// if(a[v] == 2) cout << in[d] << ' ' << d << '\n';
in[d] = 0;
}
// cout << a[v] << ' ' << cursw << '\n';
if(cur > ans || (ans == cur && cursw < sw)){
ans = cur;
sw = cursw;
}
}
if(tp == 0){
int i = dep[v];
while(cnt[i]){
cnt[i] = 0;
i++;
}
}
}
// Reads one instance, builds the tree, and prints the best team size followed
// by the minimum number of swaps for that size.
// Input: n k, then n colours (0-based), then parents of vertices 2..n (0-based).
// Colours and vertex ids are shifted to 1-based internally.
void solve(){
    cin >> n >> k;
    K = sqrt(n);

    for(int v = 1; v <= n; ++v){
        int colour = 0;
        cin >> colour;
        a[v] = colour + 1;
        e[a[v]].push_back(v);
    }
    for(int v = 2; v <= n; ++v){
        int parent = 0;
        cin >> parent;
        g[parent + 1].push_back(v);
    }

    calc(1);

    // Colours with more than K vertices are "heavy" and get a full-tree scan
    // each (dfs1); all others are "light" and are scored inside dfs2.
    vector<int> heavy;
    for(int c = 1; c <= k; ++c){
        const bool isHeavy = e[c].size() > static_cast<size_t>(K);
        if(isHeavy) heavy.push_back(c);
        else mark[c] = 1;
    }
    for(int c : heavy) dfs1(1, c);
    dfs2(1, 1);

    cout << ans << ' ' << sw;
}
// Entry point: untied, unsynchronised stream I/O, then solve() is run once.
// Multi-test reading and file redirection are left disabled below.
signed main(){
    ios_base::sync_with_stdio(0);
    cin.tie(nullptr);
    cout.tie(nullptr);
    // freopen("trains.in", "r", stdin);
    // freopen("trains.out", "w", stdout);
    int t = 1;
    // cin >> t;
    for(; t > 0; --t){
        solve();
    }
}
/*
Judge result tables (scraped page residue; results were still loading):

| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
*/