/*
بسم الله الرحمن الرحيم
Author:
(:Muhammad Aneeq:)
*/
#pragma GCC optimize("O2")
#include <iostream>
#include <map>
#include <algorithm>
#include <queue>
#include <vector>
#warning check the output
using namespace std;

// N: capacity of every per-vertex / per-depth / per-colour array.
// S: colour-size threshold. solve() evaluates colours with at most S vertices
// through check() and larger colours through dfs2()/bfs().
int const N=1e5+10,S=300;

// All arrays below have static storage duration, so they start zero / empty / null.
vector<int> nei[N];            // nei[p]: children of vertex p (tree rooted at vertex 0)
vector<int> lv[N];             // lv[d]: Euler entry times of depth-d vertices, ascending
map<int,vector<int>> co[N];    // co[col][d]: vertices of colour col at depth d
int it[N],ot[N];               // entry time of u / largest entry time inside u's subtree
int cnt[N];                    // cnt[col]: number of vertices with colour col
int dp[N];                     // dp[u]: value from check()/bfs() with u as lead; 0 if not evaluated
int c[N];                      // c[u]: colour of vertex u
int dep[N];                    // dep[u]: depth of u, root at depth 0
int tm=0;                      // next Euler-tour timestamp handed out by dfs()
map<int,int>* vl[N];           // vl[u]: colour -> count in u's subtree, built by dfs1()
int sz[N];                     // sz[u]: number of vertices in u's subtree
// Recursive Euler-tour walk of u's subtree, with u at depth h.
// Per vertex it records:
//   - depth (dep)
//   - entry time (it) and the last entry time used inside its subtree (ot)
//   - subtree size (sz)
// It also appends the vertex to co[colour][depth] and its entry time to lv[depth].
// Recursion depth equals the height of the tree below u.
void dfs(int u,int h=0)
{
    const int entry=tm++;
    it[u]=entry;
    dep[u]=h;
    sz[u]=1;
    co[c[u]][h].push_back(u);
    lv[h].push_back(entry);
    for (const int child:nei[u])
    {
        dfs(child,h+1);
        sz[u]+=sz[child];
    }
    ot[u]=tm-1;
}
// Best result found so far over all vertices:
//   ans: the largest dp[] value.
//   mx:  the smallest (dp[u] - #vertices of colour c[u] in u's subtree) among vertices reaching ans.
// Presumably the maximum team size and the minimum number of swaps; TODO confirm against the
// problem statement. The start values ans=1, mx=0 are the result when no dp value exceeds 1.
int ans=1,mx=0;
// Post-order small-to-large merge.
// After dfs1(u), *vl[u] maps colour -> number of vertices of that colour in u's subtree.
// u reuses the map of its largest child. Every other child's map is merged into it and then freed.
// Also folds u's dp value into ans/mx.
void dfs1(int u)
{
    int x=-1;   // child with the largest subtree; -1 when u is a leaf
    for (auto i:nei[u])
    {
        dfs1(i);
        if (x==-1 || sz[i]>sz[x])
            x=i;
    }
    if (x!=-1)
        vl[u]=vl[x];    // adopt the largest child's map instead of copying it
    else
        vl[u]=new map<int,int> ();
    map<int,int>& counts=*vl[u];
    counts[c[u]]++;
    for (auto i:nei[u])
    {
        if (i==x) continue;
        for (const auto& [j,k]:*vl[i])
            counts[j]+=k;
        // A merged child's map is never read again. Freeing it keeps the total size of the
        // live maps O(n), instead of keeping every intermediate map alive until exit.
        delete vl[i];
        vl[i]=nullptr;
    }
    const int own=counts[c[u]];   // vertices already of colour c[u] in u's subtree
    if (dp[u]>ans)
    {
        ans=dp[u];
        mx=dp[u]-own;
    }
    else if (dp[u]==ans)
        mx=min(mx,dp[u]-own);
}
int fn(int u,int lev)
{
int z=lower_bound(begin(lv[lev]),end(lv[lev]),it[u])-begin(lv[lev]);
int y=upper_bound(begin(lv[lev]),end(lv[lev]),ot[u])-begin(lv[lev]);
return y-z;
}
// Computes dp[s] for lead s as the sum over depths d of
//   min(#vertices at depth d in s's subtree, #vertices of colour c[s] at depth d in the tree).
// The subtree's vertices are counted per depth with a breadth-first walk.
void bfs(int s)
{
    const int col=c[s];
    map<int,int> lev;      // depth -> number of vertices of s's subtree at that depth
    queue<int> pending;    // not named S: that would shadow the global threshold constant
    pending.push(s);
    while (!pending.empty())
    {
        const int f=pending.front();
        pending.pop();
        lev[dep[f]]++;
        for (auto i:nei[f])
            pending.push(i);
    }
    const auto& byDepth=co[col];
    int total=0;           // not named ans: that would shadow the global result
    for (const auto& [d,here]:lev)
    {
        // Use find() rather than operator[]: a lookup must not insert empty depth buckets
        // into the shared co[col] map. A missing bucket contributes min(here, 0) = 0.
        const auto bucket=byDepth.find(d);
        if (bucket!=byDepth.end())
            total+=min(here,int(bucket->second.size()));
    }
    dp[s]=total;
}
// Computes the dp value for lead u of colour k by walking colour k's depth buckets:
// the sum over depths d of
//   min(#vertices at depth d in u's subtree, #colour-k vertices at depth d).
// Cost: O(#depths holding colour k * log n).
int check(int k,int u)
{
    int total=0;
    // Bind by reference: `auto [i,inds]` copied each depth bucket's vector on every call.
    for (const auto& [depth,inds]:co[k])
        total+=min(fn(u,depth),int(inds.size()));
    return total;
}
// Runs bfs() on every topmost vertex of colour cl in u's subtree, including u itself
// if it has colour cl. It does not descend below a vertex of colour cl.
void dfs2(int u,int cl)
{
    if (c[u]!=cl)
    {
        for (auto child:nei[u])
            dfs2(child,cl);
        return;
    }
    bfs(u);
}
// Reads the input from stdin, computes dp[] for the candidate leads and prints "ans mx".
// Input order:
//   - n and k
//   - the colour of each of the n vertices
//   - the parent of each vertex 1..n-1
inline void solve()
{
    int n,k;
    cin>>n>>k;
    for (int i=0;i<n;i++)
    {
        cin>>c[i];
        cnt[c[i]]++;
    }
    for (int i=1;i<n;i++)
    {
        int parent;
        cin>>parent;
        nei[parent].push_back(i);
    }
    dfs(0);
    // Colours with at most S vertices: evaluate every vertex of that colour with check().
    for (int i=0;i<n;i++)
        if (cnt[c[i]]<=S)
            dp[i]=check(c[i],i);
    // Colours with more than S vertices: evaluate only their topmost vertices via dfs2()/bfs().
    for (int col=0;col<k;col++)
        if (cnt[col]>S)
            dfs2(0,col);
    dfs1(0);
    cout<<ans<<' '<<mx<<'\n';
}
int main()
{
    // Unsynchronised iostreams: the program performs all I/O through cin/cout.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    const int tests=1;
    for (int tc=0;tc<tests;tc++)
        solve();
}
/*
Judge output pasted after the program (kept for reference; not part of the code):

Compilation message (stderr)
Main.cpp:12:2: warning: #warning check the output [-Wcpp]
   12 | #warning check the output
      |  ^~~~~~~
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/