Submission #1305638

# | Time | Username | Problem | Language | Result | Execution time | Memory
1305638 | (not shown) | syanvu | Team Coding (EGOI24_teamcoding) | C++20 | 100 / 100 | 2063 ms | 68556 KiB
#include <bits/stdc++.h>

// Template short-hands. Only SS is used in the visible code (in main);
// pb and all are unused here.
#define pb push_back
// SS: disable C-stdio synchronisation and untie cin/cout for faster I/O.
#define SS ios_base::sync_with_stdio(0);cin.tie(nullptr);cout.tie(nullptr);
// (disabled) would widen every `int` to long long.
// #define int long long
#define all(v) v.begin(),v.end()
using namespace std;

// N: bound for all per-vertex / per-colour / per-depth arrays. Vertices are
// 1..n and colours 1..k, so this assumes n, k <= 1e5 (TODO confirm against
// the problem limits). inf and mod are not used in the visible code.
const int N = 1e5 + 1, inf = 1e9 + 1, mod = 998244353;

// n: number of vertices, k: number of colours.
int n, k;
// Colour-size threshold: colours with more than K vertices are "heavy" and
// are evaluated by a full-tree sweep (dfs1); the others are "light" (dfs2).
int K = 200;
// g[v]: children of v; e[c]: the vertices whose colour is c.
vector<int> g[N], e[N];
// a[v]: colour of vertex v, shifted by +1 on input (so 1..k).
int a[N];
// cnt[d], in[d] : per-depth scratch counters (all vertices / vertices of the
//                 colour being evaluated); reset to 0 after each evaluation.
// tin, tout     : Euler-tour entry/exit times; timer is the running clock.
// sz[v]         : subtree size; boss[v]: child with the largest subtree (0 if leaf).
// mark[c]       : 1 if colour c is light.
// dep[v]        : depth, root = 0 (zero-initialised); ver[t]: vertex with tin == t.
int cnt[N], in[N], tin[N], tout[N], timer, sz[N], boss[N], mark[N], dep[N], ver[N];
// col[c][d]: number of vertices of colour c at depth d in the whole tree.
// h[v][d]  : number of vertices at depth d in the subtree of v (built by dfs2).
unordered_map<int, int> col[N], h[N];
// ans: best value found (printed first); sw: its tie-breaker, minimised
// (printed second) — presumably the number of swaps needed.
int ans, sw;

void calc(int v){
    tin[v] = ++timer;
    ver[timer] = v;
    col[a[v]][dep[v]]++;
    sz[v] = 1;
    for(int to : g[v]){
        dep[to] = dep[v] + 1;
        calc(to);
        sz[v] += sz[to];
        if(!boss[v] || sz[boss[v]] < sz[to]) boss[v] = to;
    }
    tout[v] = timer;
}

// Adds every vertex of v's subtree to cnt[depth]. Vertices whose colour is c
// are additionally added to in[depth].
void f(int v, int c){
    const int d = dep[v];
    cnt[d] += 1;
    in[d] += (a[v] == c) ? 1 : 0;
    for(auto it = g[v].begin(); it != g[v].end(); ++it){
        f(*it, c);
    }
}
void dfs1(int v, int c){
    if(a[v] == c){
        f(v, c);
        int cur = 0, cursw = 0, i = dep[v];
        while(cnt[i]){
            cur += in[i];
            cur += min(col[c][i], cnt[i]) - in[i];
            cursw += min(col[c][i], cnt[i]) - in[i];
            cnt[i] = in[i] = 0;
            i++;
        }
        // if(a[v] == 3) cout << cur;
        if(ans < cur || (ans == cur && cursw < sw)){
            ans = cur;
            sw = cursw;
        }
        return;
    }
    for(int to : g[v]){
        dfs1(to, c);
    }
}
// True iff v is in the subtree of u (u == v counts). The test checks whether
// v's Euler-tour interval is nested inside u's.
bool ch(int u, int v){
    if(tin[v] < tin[u]) return false;
    return tout[v] <= tout[u];
}
void dfs2(int v){
    for(int to : g[v]){
        if(boss[v] == to) continue;
        dfs2(to);
    }
    if(boss[v]){ 
        dfs2(boss[v]);
        swap(h[v], h[boss[v]]);
    }
    h[v][dep[v]]++;
    for(int to : g[v]){
        if(boss[v] == to) continue;
        for(auto [d, s] : h[to]) h[v][d] += s;
    }
    if(mark[a[v]]){
        unordered_set<int> st;
        for(int x : e[a[v]]){
            if(ch(v, x)) in[dep[x]]++;
            st.insert(dep[x]);
        }
        int cur = 0, cursw = 0;
        for(int d : st){
            cur += in[d];
            cur += min(col[a[v]][d], h[v][d]) - in[d];
            cursw += min(col[a[v]][d], h[v][d]) - in[d];
            // if(a[v] == 2) cout << in[d] << ' ' << d << '\n';
            in[d] = 0;
        }
        // cout << a[v] << ' ' << cursw << '\n';
        if(cur > ans || (ans == cur && cursw < sw)){
            ans = cur;
            sw = cursw;
        }
    }
}

// Reads n, k, the colours, and the parents of vertices 2..n; both colours and
// parent ids are shifted by +1. It builds the Euler tour, splits colours into
// heavy (> K vertices) and light, evaluates heavy colours with dfs1 and light
// colours with dfs2, and finally prints "ans sw".
void solve(){
    cin >> n >> k;
    for(int v = 1; v <= n; v++){
        int lang = 0;
        cin >> lang;
        a[v] = lang + 1;
        e[a[v]].push_back(v);
    }
    for(int v = 2; v <= n; v++){
        int parent = 0;
        cin >> parent;
        g[parent + 1].push_back(v);
    }
    calc(1);
    vector<int> heavy;
    for(int c = 1; c <= k; c++){
        if(e[c].size() > static_cast<size_t>(K)){
            heavy.push_back(c);
        } else {
            mark[c] = 1;
        }
    }
    for(int c : heavy) dfs1(1, c);
    dfs2(1);
    cout << ans << ' ' << sw;
}
// Entry point: enable fast I/O and solve a single test case.
signed main(){
    SS
    solve();
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...