Submission #1052731

# | Time | Username | Problem | Language | Result | Execution time | Memory
1052731 | Ludissey | Team Coding (EGOI24_teamcoding) | C++17
100 / 100
1150 ms | 72028 KiB
#include <bits/stdc++.h>
#define int long long
#define sz(a) (int)a.size()
#define all(a) a.begin(), a.end()

using namespace std;
// Children lists of the hierarchy tree: child[v] holds the direct reports of v.
vector<vector<int>> child;
// Declared but not referenced anywhere in the visible code.
vector<int> depthCOUNT;
// depth[v] = distance of v from the root (node 0); filled by the BFS in main.
vector<int> depth;
// Binary-lifting table: parent[v][i] is the 2^i-th ancestor of v (the root maps to itself).
vector<vector<int>> parent;
// Best team size found so far, and the minimum number of switches that achieves it.
int totMX=0;
int totSWITCH=0;
// Threshold on a language's speaker count: below it the language is handled
// per node inside dfs2, at or above it by one dfs pass over the whole tree.
int SQRT;
// l[v] = preferred language of employee v.
vector<int> l;
// Language currently being evaluated by dfs / addALLCHILD.
int need;
// Binary-lifting levels; bits 0..LOG-1 let lca lift by any difference below 2^31.
const int LOG=31;

// Lowest common ancestor of a and b using the binary-lifting table `parent`.
// Phase 1 lifts the deeper node up to the shallower node's depth; phase 2
// lifts both nodes in lockstep for as long as their ancestors still differ.
int lca(int a, int b){
    int lo = a, hi = b;                    // `lo` becomes the shallower node
    if (depth[lo] > depth[hi]) swap(lo, hi);

    // Phase 1: climb exactly `gap` levels, one set bit at a time.
    int gap = depth[hi] - depth[lo];
    for (int bit = LOG - 1; bit >= 0; --bit) {
        if ((gap >> bit) & 1) hi = parent[hi][bit];
    }
    if (hi == lo) return lo;

    // Phase 2: jump both nodes while the 2^bit-th ancestors are still distinct.
    for (int bit = LOG - 1; bit >= 0; --bit) {
        const int upLo = parent[lo][bit];
        const int upHi = parent[hi][bit];
        if (upLo == upHi) continue;
        lo = upLo;
        hi = upHi;
    }
    return parent[lo][0];
}
// Per-depth scratch counters used by dfs/addALLCHILD for the current language `need`:
//   depthHAS[d]    = nodes at depth d in the current subtree that speak `need`,
//   depthNOTHAS[d] = nodes at depth d in the current subtree that do not,
//   depthNEDED[d]  = nodes at depth d anywhere in the tree that speak `need`.
vector<int> depthHAS;
vector<int> depthNOTHAS;
vector<int> depthNEDED;
// lang[c] = employees whose preferred language is c.
vector<vector<int>> lang;
// Number of employees.
int N;
int addALLCHILD(int x){
    if(l[x]==need) depthHAS[depth[x]]++;
    else depthNOTHAS[depth[x]]++;
    int sm=depth[x];
    for (auto u : child[x])
    {
        sm=max(sm,addALLCHILD(u));
    }
    return sm;
}
// Descends from x until it reaches nodes speaking the current language `need`
// (it never descends below such a node). Each node reached this way is scored
// as team leader:
//   1 (the leader) + for every deeper depth d in its subtree, depthHAS[d] plus
//   as many switches as possible, capped both by speakers of `need` at depth d
//   outside the subtree (depthNEDED[d] - depthHAS[d]) and by non-speakers at
//   depth d inside it (depthNOTHAS[d]).
// totMX / totSWITCH keep the best size and, on ties, the fewest switches.
void dfs(int x){
    if (l[x] != need) {
        for (auto u : child[x]) dfs(u);
        return;
    }

    const int deepest = addALLCHILD(x);
    int teamSize = 1;
    int switches = 0;
    for (int d = depth[x] + 1; d <= deepest; ++d) {
        const int moved = min(depthNEDED[d] - depthHAS[d], depthNOTHAS[d]);
        teamSize += depthHAS[d] + moved;
        switches += moved;
    }

    // Clear the scratch counters (x's own depth included) for the next leader.
    for (int d = depth[x]; d <= deepest; ++d) {
        depthHAS[d] = 0;
        depthNOTHAS[d] = 0;
    }

    if (teamSize > totMX) {
        totMX = teamSize;
        totSWITCH = switches;
    } else if (teamSize == totMX && switches < totSWITCH) {
        totSWITCH = switches;
    }
}

// Returns a map depth -> number of nodes at that depth inside the subtree
// rooted at x. Child maps are merged small-to-large: the largest map is reused
// as the accumulator and every other map is folded into it.
//
// While unwinding, if x's language has fewer than SQRT speakers, x is also
// scored as team leader. Each speaker y of that language lying strictly deeper
// than x is classified per depth as inside x's subtree (x is an ancestor of y,
// tested via lca) or outside it. At each depth, outside speakers can be
// switched in only as far as the subtree has non-speakers at that depth.
// totMX / totSWITCH keep the best size and, on ties, the fewest switches.
unordered_map<int,int> dfs2(int x){
    vector<unordered_map<int,int>> mps;
    mps.reserve(child[x].size() + 1);
    for (auto u : child[x]) mps.push_back(dfs2(u));
    // x itself contributes one node at its own depth.
    mps.emplace_back();
    mps.back()[depth[x]]++;

    // Only the largest map has to be in front; a full sort is unnecessary.
    auto biggest = std::max_element(mps.begin(), mps.end(),
        [](const auto &aa, const auto &bb) { return aa.size() < bb.size(); });
    if (biggest != mps.begin()) std::swap(*biggest, mps.front());  // O(1) for unordered_map
    unordered_map<int,int> &merged = mps.front();
    for (size_t i = 1; i < mps.size(); i++) {
        for (const auto &entry : mps[i]) merged[entry.first] += entry.second;
    }

    const vector<int> &speakers = lang[l[x]];
    if (sz(speakers) < SQRT) {
        // depth -> (speakers inside x's subtree, speakers outside it)
        unordered_map<int,pair<int,int>> dpth;
        for (auto y : speakers) {
            if (y == x || depth[y] <= depth[x]) continue;
            if (lca(y, x) == x) dpth[depth[y]].first++;
            else dpth[depth[y]].second++;
        }
        int sm = 1;    // the leader x itself
        int swtch = 0;
        for (const auto &entry : dpth) {
            // find() instead of operator[]: operator[] would insert zero-count
            // depths into `merged`, which is returned and re-merged by every
            // ancestor. A missing depth has no subtree nodes, so it counts as 0.
            auto it = merged.find(entry.first);
            const int inSubtree = (it == merged.end()) ? 0 : it->second;
            const int gave = min(entry.second.second, inSubtree - entry.second.first);
            sm += entry.second.first + gave;
            swtch += gave;
        }
        if (sm > totMX) {
            totMX = sm;
            totSWITCH = swtch;
        } else if (sm == totMX) {
            totSWITCH = min(totSWITCH, swtch);
        }
    }
    return std::move(merged);
}

// Reads N employees and K languages, each employee's language, and the boss of
// every non-root employee (node 0 is the root). It then evaluates every leader:
//   * languages with fewer than SQRT speakers inside dfs2, node by node;
//   * languages with at least SQRT speakers by one dfs pass each.
// Prints the largest achievable team size and the minimum number of switches.
signed main() {
    ios_base::sync_with_stdio(false); cin.tie(nullptr);
    int K; cin >> N >> K;
    l.resize(N);
    parent.resize(N,vector<int>(LOG,0));
    depth.resize(N);
    child.resize(N);
    lang.resize(K);
    depthHAS.resize(N,0);
    depthNOTHAS.resize(N,0);
    depthNEDED.resize(N,0);
    SQRT=sqrt(N);
    for (int i = 0; i < N; i++) cin >> l[i];
    for (int i = 0; i < N; i++) lang[l[i]].push_back(i);

    // With the root as leader, every speaker of the root's language is already
    // in its subtree, so this team size needs zero switches.
    const int rootTeam = sz(lang[l[0]]);

    parent[0][0]=0;
    for (int i = 1; i < N; i++) {
        cin >> parent[i][0];
        child[parent[i][0]].push_back(i);
    }
    // 2^j-th ancestor = 2^(j-1)-th ancestor of the 2^(j-1)-th ancestor.
    for (int j = 1; j < LOG; j++) {
        for (int i = 0; i < N; i++) {
            parent[i][j]=parent[parent[i][j-1]][j-1];
        }
    }

    // BFS from the root to assign depths.
    queue<int> que;
    que.push(0);
    while(!que.empty()){
        int x=que.front(); que.pop();
        for (auto u : child[x]) {
            depth[u]=depth[x]+1;
            que.push(u);
        }
    }

    dfs2(0);
    for (int i = 0; i < K; i++) {
        if (sz(lang[i]) < SQRT) continue;
        // depthNEDED[d] = speakers of language i at depth d across the whole tree.
        for (auto v : lang[i]) depthNEDED[depth[v]]++;
        need=i;
        dfs(0);
        for (auto v : lang[i]) depthNEDED[depth[v]]=0;
    }

    if(rootTeam>=totMX){
        totMX=rootTeam;
        totSWITCH=0;
    }
    cout << totMX <<  " " << totSWITCH << "\n";
    return 0;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...