# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
---|---|---|---|---|---|---|---|
1195242 |  | badge881 | Team Coding (EGOI24_teamcoding) | C++20 |  | 0 ms | 0 KiB |
#include <bits/stdc++.h>
using namespace std;
// Bound shared by every table below; they are indexed by node id, depth, or colour.
const int N = 1e5 + 10;
// Colours with at most S nodes are scored directly (check); larger colours use BFS sweeps.
const int S = 300;

vector<int> nei[N];          // nei[u]: children of node u
vector<int> lv[N];           // lv[d]: entry times of the depth-d nodes, in increasing order
map<int, vector<int>> co[N]; // co[col][d]: nodes of colour col found at depth d

int it[N];  // it[u]: preorder entry time of u
int ot[N];  // ot[u]: largest entry time inside u's subtree
int cnt[N]; // cnt[col]: number of nodes with colour col
int dp[N];  // dp[u]: sum over depths of min(#subtree nodes, #colour-c[u] nodes); 0 if never evaluated
int c[N];   // c[u]: colour of node u
int dep[N]; // dep[u]: depth of u (the root is visited at depth 0)
int sz[N];  // sz[u]: number of nodes in u's subtree

int tm = 0; // next preorder timestamp to hand out

// vl[u]: colour -> count for u's subtree, built by dfs1 (a node may reuse its heavy child's map)
map<int, int> *vl[N] = {};
void dfs(int u, int h = 0)
{
    // Preorder walk rooted at u (u sits at depth h). For every node it records the
    // depth, colour/depth bucket, entry time, level list and subtree size; ot[] is
    // written when the node's subtree is finished. An explicit frame stack replaces
    // recursion so path-shaped trees do not grow the call stack.
    auto enter = [&](int node, int depth) {
        dep[node] = depth;
        co[c[node]][depth].push_back(node);
        it[node] = tm++;
        sz[node] = 1;
        lv[depth].push_back(it[node]); // appended in preorder, so each level stays sorted
    };

    // Each frame is (node, index of the next child to descend into).
    vector<pair<int, size_t>> frames;
    enter(u, h);
    frames.emplace_back(u, 0);

    while (!frames.empty())
    {
        const int node = frames.back().first;
        const size_t next = frames.back().second;

        if (next < nei[node].size())
        {
            frames.back().second = next + 1;
            const int child = nei[node][next];
            enter(child, dep[node] + 1);
            frames.emplace_back(child, 0);
            continue;
        }

        // Subtree of `node` is complete: close its interval and report its size upward.
        ot[node] = tm - 1;
        frames.pop_back();
        if (!frames.empty())
            sz[frames.back().first] += sz[node];
    }
}
int ans = 1, mx = 0;
void dfs1(int u)
{
int x = -1;
for (auto i : nei[u])
{
dfs1(i);
if (x == -1)
x = i;
if (sz[i] > sz[x])
x = i;
}
if (x != -1)
vl[u] = vl[x];
else
vl[u] = new map<int, int>();
(*vl[u])[c[u]]++;
for (auto i : nei[u])
{
if (i == x)
continue;
for (auto [j, k] : *(vl[i]))
(*vl[u])[j] += k;
}
if (dp[u] > ans)
ans = dp[u],
mx = dp[u] - (*vl[u])[c[u]];
if (dp[u] == ans)
mx = min(mx, dp[u] - (*vl[u])[c[u]]);
}
// Number of nodes at depth `lev` inside u's subtree: lv[lev] holds that level's
// entry times in increasing order, and u's subtree is exactly the entry-time
// interval [it[u], ot[u]].
int fn(int u, int lev)
{
    const vector<int> &row = lv[lev];
    const auto first = lower_bound(row.begin(), row.end(), it[u]);
    // Everything before `first` is < it[u] <= ot[u], so searching from `first` is equivalent.
    const auto last = upper_bound(first, row.end(), ot[u]);
    return int(last - first);
}
// Computes dp[s] for a node whose colour is "heavy": walks s's subtree
// breadth-first, counts its nodes per depth, and caps each depth's count by the
// number of nodes of colour c[s] at that depth in the whole tree.
//
// The colour buckets are only *read* here. The previous `co[col][i]` used
// map::operator[], which inserted an empty vector for every depth lacking colour
// col. Over many heavy colours on a deep tree, that amounted to millions of
// useless map nodes.
void bfs(int s)
{
    const int col = c[s];
    const map<int, vector<int>> &byDepth = co[col];

    // perLevel[d] = number of subtree nodes at depth dep[s] + d.
    vector<int> perLevel;
    queue<int> pending;
    pending.push(s);
    while (!pending.empty())
    {
        const int f = pending.front();
        pending.pop();
        const size_t rel = size_t(dep[f] - dep[s]);
        if (rel >= perLevel.size())
            perLevel.resize(rel + 1, 0);
        ++perLevel[rel];
        for (int ch : nei[f])
            pending.push(ch);
    }

    int total = 0;
    for (size_t d = 0; d < perLevel.size(); ++d)
    {
        const auto found = byDepth.find(dep[s] + int(d));
        if (found != byDepth.end()) // depths without colour col contribute 0
            total += min(perLevel[d], int(found->second.size()));
    }
    dp[s] = total;
}
// dp value for lead u with colour k, used for "light" colours: for every depth
// that holds colour-k nodes, add min(#nodes of u's subtree at that depth,
// #colour-k nodes at that depth). Depths above u contribute 0 because fn() finds
// no subtree nodes there.
int check(int k, int u)
{
    int total = 0;
    // Bind by const reference: `auto [i, inds]` copied each depth's vector on every call.
    for (const auto &[depth, nodes] : co[k])
        total += min(fn(u, depth), int(nodes.size()));
    return total;
}
// Walks down from u and runs bfs() on each topmost node of colour cl, i.e. a
// matching node stops the descent into its own subtree. Nodes are visited in
// the same preorder as a recursive walk: children are pushed in reverse so the
// first child is popped first. The explicit stack avoids deep recursion.
void dfs2(int u, int cl)
{
    vector<int> todo{u};
    while (!todo.empty())
    {
        const int v = todo.back();
        todo.pop_back();
        if (c[v] == cl)
        {
            bfs(v);
            continue;
        }
        for (auto ch = nei[v].rbegin(); ch != nei[v].rend(); ++ch)
            todo.push_back(*ch);
    }
}
// Input: n and k, then the colour of each of the n nodes, then the parent of
// nodes 1..n-1 (node 0 is the root). Prints the best dp value over all nodes and
// the smallest corresponding (dp - same-colour subtree count), as tracked by dfs1.
// Returns 1 if the input cannot be parsed.
int main()
{
    int n, k;
    if (scanf("%d %d", &n, &k) != 2)
        return 1;
    for (int i = 0; i < n; i++)
    {
        if (scanf("%d", &c[i]) != 1)
            return 1;
        cnt[c[i]]++;
    }
    for (int i = 1; i < n; i++)
    {
        int parent;
        if (scanf("%d", &parent) != 1)
            return 1;
        nei[parent].push_back(i);
    }

    dfs(0);
    // Light colours (at most S nodes): score every node directly.
    for (int i = 0; i < n; i++)
        if (cnt[c[i]] <= S)
            dp[i] = check(c[i], i);
    // Heavy colours: one sweep per colour over its topmost occurrences.
    for (int col = 0; col < k; col++)
        if (cnt[col] > S)
            dfs2(0, col);
    dfs1(0);
    printf("%d %d\n", ans, mx);
}