Submission #1124807

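| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1124807 | | salmon | Team Coding (EGOI24_teamcoding) | C++20 | 81 / 100 | 4128 ms | 890384 KiB |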
#include <bits/stdc++.h>
using namespace std;

// Capacity of every per-vertex / per-label / per-depth table below.
constexpr int kMaxN = 100100;

// Labels carried by at least B vertices are "heavy" (one whole-tree pass each
// in solve1); all other labels are "light" (evaluated per vertex in solve).
constexpr int B = 303;

int N, K;                        // vertex count and label count read from input
int lst[kMaxN];                  // lst[v]: label of vertex v
int parent[kMaxN];               // parent[v]; main sets parent[0] = -1
vector<int> children[kMaxN];     // children of each vertex, in input order
int num[kMaxN];                  // num[c]: number of vertices with label c
vector<int> lc[kMaxN];           // lc[c]: the vertices with label c
vector<int> srt;                 // the heavy labels
bool lo[kMaxN];                  // lo[c]: label c is light

int d[kMaxN];                    // depth of each vertex
int pre[kMaxN];                  // pre-order index of each vertex
int post[kMaxN];                 // largest pre-order index inside the subtree
int cont = 0;                    // next pre-order index to assign

int templst[kMaxN];              // scratch: per-depth count of the current label
pair<int,int> ans;               // best (team size, -swaps) found so far

// Numbers the subtree rooted at i in pre-order and records depths.
//   d[v]    = depth of v (i gets `de`, each child one more than its parent)
//   pre[v]  = pre-order index of v (taken from the global counter `cont`)
//   post[v] = largest pre-order index inside v's subtree
// So w is in v's subtree iff pre[v] <= pre[w] <= post[v].
// An explicit stack of (vertex, next-child position) frames replaces the
// recursion; children are visited in stored order, so the numbering matches
// a recursive pre-order walk exactly.
void dfs(int i, int de){
    vector<pair<int, size_t>> frames;

    d[i] = de;
    pre[i] = cont++;
    frames.push_back({i, 0});

    while(!frames.empty()){
        const int u = frames.back().first;
        size_t &next = frames.back().second;

        if(next < children[u].size()){
            const int c = children[u][next];
            ++next;                      // must happen before push_back may reallocate
            d[c] = d[u] + 1;
            pre[c] = cont++;
            frames.push_back({c, 0});
        }
        else{
            post[u] = cont - 1;          // every descendant of u is numbered now
            frames.pop_back();
        }
    }
}

vector<int>* solve(int i){
    vector<vector<int>*> v;
    pair<int,int> p = {0,-1};
    vector<int>* vec = new vector<int>();

    for(int j : children[i]){
        v.push_back(solve(j));
        p = max(p,make_pair( (int) (v.back() -> size()), (int)v.size() - 1 ));
    }

    if(v.empty()){
        vec -> push_back(1);
    }
    else{
        vec = v[p.second];

        for(int j = 0; j < v.size(); j++){
            if(j == p.second) continue;

            int it = vec -> size() - 1;

            while(! (v[j] -> empty()) ){
                (*vec)[it] += v[j] -> back();
                v[j] -> pop_back();
                it--;
            }
        }

        vec -> push_back(1);
    }

    if(lo[lst[i]]){
        int k = lst[i];
        int ans = 0;
        int ans1 = 0;
        vector<int> acc;

        for(int j : lc[k]){
            if(pre[i] <= pre[j] && pre[j] <= post[i]) ans1++;
            templst[d[j]]++;
            if(d[j] >= d[i] && d[j] <= d[i] + vec -> size() - 1) acc.push_back(d[j]);
        }

        sort(acc.begin(),acc.end());
        acc.resize(unique(acc.begin(), acc.end()) - acc.begin());
        //printf("s: ");
        for(int a : acc){
            ans += min( (*vec)[vec -> size() - 1 - (a - d[i])], templst[a]);
        }
        //printf("\n");

        ans1 = ans - ans1;
        //printf("%d %d\n",ans,ans1);

        ::ans = max(::ans,{ans,-ans1});

        for(int j : lc[k]){
            templst[d[j]]--;
        }
    }

    return vec;
}

pair<vector<int>*,pair<int,int>> solve1(int i, int k){
    vector<vector<int>*> v;
    pair<int,int> p = {0,-1};
    vector<int>* vec = new vector<int>();
    vector<pair<int,int>> vii;

    for(int j : children[i]){
        pair<vector<int>*,pair<int,int>> iiii = solve1(j,k);
        vii.push_back(iiii.second);
        v.push_back(iiii.first);
        p = max(p,make_pair( (int) (v.back() -> size()), (int)v.size() - 1 ));
    }

    int ans = 0;
    int ans1 = 0;

    if(v.empty()){
        vec -> push_back(1);
        if(lst[i] == k) ans1++;

        if(templst[d[i]] > 0) ans = 1;
    }
    else{
        vec = v[p.second];
        ans = vii[p.second].first;
        ans1 = vii[p.second].second;

        for(int j = 0; j < v.size(); j++){
            if(j == p.second) continue;

            int it = vec -> size() - 1;

            while(! (v[j] -> empty()) ){
                ans -= min((*vec)[it], templst[d[i] + vec -> size() - it]);
                (*vec)[it] += v[j] -> back();
                ans += min((*vec)[it], templst[d[i] + vec -> size() - it]);
                v[j] -> pop_back();
                it--;
            }
        }

        vec -> push_back(1);

        for(int j = 0; j < v.size(); j++){
            if(j == p.second) continue;

            ans1 += vii[j].second;
        }

        if(lst[i] == k) ans1++;
        if(templst[d[i]] > 0) ans++;
    }

    if(lst[i] == k) ::ans = max(::ans,{ans,ans1-ans});

    return {vec,{ans,ans1}};
}

int main(){
    scanf(" %d",&N);
    scanf(" %d",&K);

    for(int i = 0; i < K; i++){
        num[i] = 0;
        lo[i] = false;
        templst[i] = 0;
    }

    for(int i = 0; i < N; i++){
        scanf(" %d",&lst[i]);
        num[lst[i]]++;
        lc[lst[i]].push_back(i);
    }

    parent[0] = -1;
    for(int i = 1; i < N; i++){
        scanf(" %d",&parent[i]);
        children[parent[i]].push_back(i);
    }

    dfs(0,0);

    for(int i = 0; i < K; i++){
        if(num[i] >= B){
            srt.push_back(i);
        }
        else lo[i] = true;
    }

    ans = {0,-1};

    delete solve(0);

    //printf("%d %d",ans.first,-ans.second);

    for(int i : srt){
        for(int i = 0; i < N; i++){
            templst[i] = 0;
        }
        for(int j : lc[i]){
            templst[d[j]]++;
        }
        delete solve1(0,i).first;
    }

    printf("%d %d",ans.first,-ans.second);
}

Compilation message (stderr)

Main.cpp: In function 'int main()':
Main.cpp:157:10: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  157 |     scanf(" %d",&N);
      |     ~~~~~^~~~~~~~~~
Main.cpp:158:10: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  158 |     scanf(" %d",&K);
      |     ~~~~~^~~~~~~~~~
Main.cpp:167:14: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  167 |         scanf(" %d",&lst[i]);
      |         ~~~~~^~~~~~~~~~~~~~~
Main.cpp:174:14: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  174 |         scanf(" %d",&parent[i]);
      |         ~~~~~^~~~~~~~~~~~~~~~~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...