제출 #1264367

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1264367 | testaccount | Team Coding (EGOI24_teamcoding) | C++20
100 / 100
1767 ms | 27720 KiB
#include <bits/stdc++.h>
// Short-hand macros used in the solution.
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
using namespace std;

// Common type aliases.
using LL = long long;
using pii = pair<int, int>;
using vi = vector<int>;

/**
 * Read one integer of type T from stdin.
 *
 * Skips every character up to the first digit. The value is negated only when
 * the character immediately before that digit is '-'. Reading stops at the
 * first non-digit after the number, and that character is consumed.
 * If EOF is reached before any digit, returns 0 instead of looping forever.
 */
template<class T> inline T re(){
    // getchar() returns int; keeping it in an int lets EOF be told apart
    // from a real character.
    int c = getchar();
    bool neg = false;
    while (c != EOF && (c < '0' || c > '9')) {
        neg = (c == '-');  // only a '-' adjacent to the digits counts
        c = getchar();
    }
    T N = 0;
    for (; c >= '0' && c <= '9'; c = getchar())
        N = N * 10 + (c - '0');
    return neg ? -N : N;
}

// Colours whose col_count is at least SQRT go through process_whole_tree;
// all others go through process_small (see the dispatch loop in main).
constexpr int SQRT = 325;
constexpr int MX = 100000;  // capacity bound for the per-node / per-depth arrays

int n, k;                    // number of nodes, number of colours
int c[MX + 5];               // colour of each node
int p[MX + 5];               // parent of each non-root node
int col_count[MX + 5];       // number of nodes of each colour
int dep[MX + 5];             // depth of each node (root 0 is at depth 0)
int dep_cnt[MX + 5];         // number of nodes at each depth
int max_dep = 0;             // greatest depth present in the tree
vi chld[MX + 5];             // children of each node, in input order
vi is_col[MX + 5];           // nodes of each colour, by increasing index
vi at_dep[MX + 5];           // nodes at each depth, ordered by in_tme

// Scratch counters shared by the per-colour workers.
int cur_col[MX + 5];         // per depth: whole-tree count of nodes with the current colour
int cur_dep[MX + 5];         // per depth: number of nodes in the current subtree
int cur_at[MX + 5];          // per depth: nodes in the current subtree with the current colour
int ans_tot = 0;             // best team size found so far
int ans_moves = 0;           // fewest swaps achieving ans_tot

int flat[MX + 5];            // flat[t] = node entered at pre-order time t (1-based)
int in_tme[MX + 5];          // pre-order entry time of each node
int out_tme[MX + 5];         // largest entry time inside each node's subtree

/**
 * Evaluate the candidate leaders of colour `col` by scanning the whole tree.
 * main() uses this for colours with col_count[col] >= SQRT.
 *
 * Only the topmost nodes of colour `col` (no ancestor of that colour) are
 * evaluated. Take an ancestor u of a same-coloured node w. At every depth
 * below w, u's subtree contains w's subtree. At depth dep[w], u's subtree
 * also contains w itself. So u's team is strictly larger, and w can never
 * be the answer.
 *
 * The subtrees of topmost nodes are disjoint. Each depth sweep stops at the
 * deepest level of its own subtree, so one call runs in O(n) time.
 *
 * Updates ans_tot / ans_moves (max team size, ties broken by fewest moves).
 * Expects cur_dep and cur_at to be all-zero on entry. Restores cur_col,
 * cur_dep and cur_at before returning.
 */
void process_whole_tree(int col) {
    // cur_col[d] = number of colour-col nodes at depth d in the whole tree.
    for (int i = 0; i < n; i++) {
        if (c[i] == col) cur_col[dep[i]]++;
    }

    std::queue<int> q;     // BFS over nodes lying above the topmost col-nodes
    std::vector<int> sub;  // BFS order of the subtree currently evaluated (reused)
    q.push(0);
    while (!q.empty()) {
        const int u = q.front();
        q.pop();
        if (c[u] != col) {
            for (int nx : chld[u]) q.push(nx);
            continue;
        }

        // u is a topmost node of colour col: tally its subtree per depth.
        sub.assign(1, u);
        int deepest = dep[u];
        for (std::size_t i = 0; i < sub.size(); i++) {
            const int v = sub[i];
            cur_dep[dep[v]]++;
            if (c[v] == col) cur_at[dep[v]]++;
            deepest = std::max(deepest, dep[v]);
            for (int w : chld[v]) sub.push_back(w);
        }

        // With u as leader, each deeper level can hold at most as many
        // col-nodes as it has slots in u's subtree, and as there are col-nodes
        // at that depth overall. Col-nodes already in place need no swap.
        // Below `deepest` the subtree has no nodes, so those levels add nothing.
        int cur_tot = 1, moves = 0;  // start with 1: the leader u itself
        for (int d = dep[u] + 1; d <= deepest; d++) {
            const int take = std::min(cur_dep[d], cur_col[d]);
            cur_tot += take;
            moves += take - cur_at[d];
        }
        if (cur_tot > ans_tot) {
            ans_tot = cur_tot;
            ans_moves = moves;
        } else if (cur_tot == ans_tot) {
            ans_moves = std::min(ans_moves, moves);
        }

        // Undo this subtree's tallies so the scratch arrays are zero again.
        for (int v : sub) {
            cur_dep[dep[v]]--;
            if (c[v] == col) cur_at[dep[v]]--;
        }
    }

    for (int i = 0; i < n; i++) {
        if (c[i] == col) cur_col[dep[i]]--;
    }
}

void process_small(int col) {
    for (int x : is_col[col]) cur_col[dep[x]]++;
    // cerr << "cur_col: "; for (int i = 0; i <= max_dep; i++) cerr << "(" << i << ": " << cur_col[i] << ") "; cerr << '\n';
    // cerr << "elts: "; for (int i : is_col[col]) cerr << " " << i; cerr << '\n';
    for (int rt : is_col[col]) {
        int cur_tot = 1, moves = 0;
        // cerr << "-- root " << rt << " --\n";
        vi deps;
        for (int nx : is_col[col]) {
            if (nx == rt) continue;
            deps.pb(dep[nx]);
            if (in_tme[nx] >= in_tme[rt] && in_tme[nx] <= out_tme[rt]) cur_at[dep[nx]]++;
            cur_dep[dep[nx]] = (
                upper_bound(at_dep[dep[nx]].begin(), at_dep[dep[nx]].end(), flat[out_tme[rt]], [](int x, int y) {
                        return in_tme[x] < in_tme[y];
                    }) -
                lower_bound(at_dep[dep[nx]].begin(), at_dep[dep[nx]].end(), rt, [](int x, int y) {
                        return in_tme[x] < in_tme[y];
                    })
            );
        }

        sort(deps.begin(), deps.end());
        deps.erase(unique(deps.begin(), deps.end()), deps.end());

        // cerr << "cur_dep: "; for (int i = 0; i <= max_dep; i++) cerr << "(" << i << ": " << cur_dep[i] << ") "; cerr << '\n';
        // cerr << "cur_at: "; for (int i = 0; i <= max_dep; i++) cerr << "(" << i << ": " << cur_at[i] << ") "; cerr << '\n';
        // cerr << "deps: "; for (int i : deps) cerr << " " << i; cerr << '\n';

        for (int dd : deps) {
            if (dd > dep[rt]) {
                int tmptmp = min(cur_dep[dd], cur_col[dd]);
                cur_tot += tmptmp;
                moves += tmptmp - cur_at[dd];
            }
        }
        if (cur_tot > ans_tot) {
            ans_tot = cur_tot, ans_moves = moves;
        } else if (cur_tot == ans_tot) {
            ans_moves = min(ans_moves, moves);
        }
        // cerr << "cur : " << cur_tot << ' ' << moves << '\n';

        for (int nx : is_col[col]) {
            if (nx == rt) continue;
            if (in_tme[nx] >= in_tme[rt] && in_tme[nx] <= out_tme[rt]) cur_at[dep[nx]]--;
            cur_dep[dep[nx]] = 0;
        }
    }
    // cerr << "\n\n";
    for (int x : is_col[col]) cur_col[dep[x]]--;
}

int main() {
    /**
     * for each node x, calculate:
     * 1) number of nodes y with dep[y] > x with caveat each level has max card(z in subtree x with dep[z] fixed)
     * 2) min number of swaps
     */
    n = re<int>(); k = re<int>();
    for (int i = 0; i < n; i++) {
        col_count[c[i] = re<int>()]++;
        is_col[c[i]].pb(i);
    }
    for (int i = 1; i < n; i++) {
        p[i] = re<int>();
        chld[p[i]].pb(i);
    }
    [&]() {
        int _tme = 0;
        function<void(int)> dfs_tme;
        dfs_tme = [&](int u) -> void{
             flat[++_tme] = u;
             in_tme[u] = _tme;
             max_dep = max(max_dep, dep[u]);
             for (int nx : chld[u]) {
                 dep[nx] = dep[u] + 1;
                 dfs_tme(nx);
             }
             out_tme[u] = _tme;
        };
        dfs_tme(0);
    } ();
    for (int i = 0; i < n; i++) {
        dep_cnt[dep[i]]++;
        at_dep[dep[i]].pb(i);
    }
    for (int i = 0; i <= max_dep; i++) {
        sort(at_dep[i].begin(), at_dep[i].end(), [](int x, int y) { return in_tme[x] < in_tme[y];});
    }
    for (int i = 0; i < k; i++) {
        if (col_count[i] >= SQRT) process_whole_tree(i);
        else process_small(i);
    }

    printf("%d %d\n", ans_tot, ans_moves);
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...