Submission #1247966

# | Time | Username | Problem | Language | Result | Execution time | Memory
1247966 | | mosesmayer | Team Coding (EGOI24_teamcoding) | C++20 | 100 / 100 | 1756 ms | 27720 KiB
// EGOI 2024 "Team Coding" (EGOI24_teamcoding).
//
// Input: n, k, the colour (language) c[i] of every node 0..n-1, then the parent
// of each node 1..n-1; node 0 is the root.
//
// For a candidate lead r of colour col, every depth d strictly below r gives
//   take(d)  = min(#nodes of r's subtree at depth d, #nodes of colour col at d)
//   moves(d) = take(d) - #nodes of colour col already in r's subtree at depth d
// The team size is 1 + sum take(d) and its cost is sum moves(d). The program
// prints the largest team size and, among those, the smallest cost.
//
// Colours with at least SQRT nodes are handled by one O(n) sweep of the tree;
// rarer colours by pairwise work over their own nodes plus binary searches
// (O(cnt^2 log n) per colour, so O(n * SQRT * log n) overall).
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <queue>
#include <utility>
#include <vector>
using namespace std;

// Reads the next integer from stdin, skipping any non-digit characters before
// it (a '-' among them makes the result negative). Returns 0 if EOF is hit
// before a digit instead of looping forever.
template <class T>
inline T re() {
    T N = 0;
    int ch = getchar();  // int, not char, so EOF is distinguishable
    bool neg = false;
    for (; ch != EOF && (ch < '0' || ch > '9'); ch = getchar()) neg |= ch == '-';
    for (; ch >= '0' && ch <= '9'; ch = getchar()) N = N * 10 + (ch - '0');
    return neg ? -N : N;
}

const int SQRT = 325;   // colours with >= SQRT nodes use the whole-tree sweep
const int MX = 100000;  // largest n the static arrays are sized for

int n, k;
int c[MX + 5];                        // colour of each node
int dep[MX + 5];                      // depth of each node (root = 0)
vector<int> chld[MX + 5];             // children, in input order
vector<int> is_col[MX + 5];           // nodes of each colour, increasing index
vector<int> at_dep[MX + 5];           // nodes at each depth, sorted by in_tme
int cur_col[MX + 5];                  // per depth: nodes of the current colour in the whole tree
int cur_dep[MX + 5];                  // per depth: nodes in the current subtree
int cur_at[MX + 5];                   // per depth: current-colour nodes in the current subtree
int ans_tot = 0, ans_moves = 0;
int flat[MX + 5];                     // flat[t] = node with preorder time t (1-based)
int in_tme[MX + 5], out_tme[MX + 5];  // subtree of u = preorder times [in_tme[u], out_tme[u]]

// Records a candidate: a bigger team wins; equal teams keep the fewer moves.
void consider(int tot, int moves) {
    if (tot > ans_tot) {
        ans_tot = tot;
        ans_moves = moves;
    } else if (tot == ans_tot) {
        ans_moves = min(ans_moves, moves);
    }
}

// Whole-tree sweep for a frequent colour `col`. Only topmost colour-col nodes
// (no colour-col ancestor) are tried as leads: a colour-col ancestor's subtree
// contains the descendant's subtree plus the descendant itself, so it always
// yields a strictly larger team. The tried subtrees are disjoint, so the sweep
// is O(n).
void process_whole_tree(int col) {
    for (int i = 0; i < n; i++)
        if (c[i] == col) cur_col[dep[i]]++;

    queue<int> q;
    q.push(0);
    vector<int> sub;  // nodes of the subtree being evaluated, in BFS order
    while (!q.empty()) {
        const int u = q.front();
        q.pop();
        if (c[u] != col) {
            for (int v : chld[u]) q.push(v);
            continue;
        }

        // Collect u's subtree and tally its per-depth counts.
        sub.assign(1, u);
        int deepest = dep[u];
        for (size_t i = 0; i < sub.size(); i++) {
            const int v = sub[i];
            cur_dep[dep[v]]++;
            if (c[v] == col) cur_at[dep[v]]++;
            deepest = max(deepest, dep[v]);
            for (int w : chld[v]) sub.push_back(w);
        }

        // Levels below `deepest` hold no subtree node and contribute 0, so stop
        // there instead of at the tree's maximum depth (that was
        // O(#roots * height), e.g. many shallow roots beside a long chain).
        int cur_tot = 1, moves = 0;  // the lead itself counts as 1
        for (int d = dep[u] + 1; d <= deepest; d++) {
            const int take = min(cur_dep[d], cur_col[d]);
            cur_tot += take;
            moves += take - cur_at[d];
        }
        consider(cur_tot, moves);

        for (int v : sub) {
            cur_dep[dep[v]]--;
            if (c[v] == col) cur_at[dep[v]]--;
        }
    }

    for (int i = 0; i < n; i++)
        if (c[i] == col) cur_col[dep[i]]--;
}

// Number of nodes at depth d inside rt's subtree. Relies on at_dep[d] being
// sorted by in_tme, so the subtree's nodes form one contiguous run.
int count_in_subtree(int d, int rt) {
    const vector<int>& level = at_dep[d];
    const int lo = in_tme[rt], hi = out_tme[rt];
    auto first = partition_point(level.begin(), level.end(),
                                 [&](int v) { return in_tme[v] < lo; });
    auto last = partition_point(first, level.end(),
                                [&](int v) { return in_tme[v] <= hi; });
    return static_cast<int>(last - first);
}

// Evaluation for a rare colour `col`: every colour-col node is tried as lead.
// Only depths below the lead that hold another colour-col node can contribute
// (elsewhere cur_col is 0).
void process_small(int col) {
    const vector<int>& nodes = is_col[col];
    for (int x : nodes) cur_col[dep[x]]++;

    vector<int> deps;
    for (int rt : nodes) {
        deps.clear();
        for (int x : nodes) {
            if (x == rt || dep[x] <= dep[rt]) continue;
            deps.push_back(dep[x]);
            if (in_tme[x] >= in_tme[rt] && in_tme[x] <= out_tme[rt]) cur_at[dep[x]]++;
        }
        sort(deps.begin(), deps.end());
        deps.erase(unique(deps.begin(), deps.end()), deps.end());

        int cur_tot = 1, moves = 0;  // the lead itself counts as 1
        for (int d : deps) {
            const int take = min(count_in_subtree(d, rt), cur_col[d]);
            cur_tot += take;
            moves += take - cur_at[d];
        }
        consider(cur_tot, moves);

        // cur_at was only raised at depths listed in deps, starting from 0.
        for (int d : deps) cur_at[d] = 0;
    }

    for (int x : nodes) cur_col[dep[x]]--;
}

// Fills dep[], flat[], in_tme[] and out_tme[] by a preorder DFS from node 0,
// visiting children in chld[] order. Uses an explicit stack: a recursive DFS
// goes n levels deep on a path-shaped tree and may overflow the call stack.
void compute_euler_tour() {
    int timer = 0;
    vector<pair<int, size_t>> stk;  // (node, index of next child to visit)
    dep[0] = 0;
    flat[++timer] = 0;
    in_tme[0] = timer;
    stk.emplace_back(0, 0);
    while (!stk.empty()) {
        const int u = stk.back().first;
        size_t& next = stk.back().second;
        if (next < chld[u].size()) {
            const int v = chld[u][next++];
            dep[v] = dep[u] + 1;
            flat[++timer] = v;
            in_tme[v] = timer;
            stk.emplace_back(v, 0);  // `next` is not used after this point
        } else {
            out_tme[u] = timer;
            stk.pop_back();
        }
    }
}

int main() {
    n = re<int>();
    k = re<int>();
    for (int i = 0; i < n; i++) {
        c[i] = re<int>();
        is_col[c[i]].push_back(i);
    }
    for (int i = 1; i < n; i++) chld[re<int>()].push_back(i);

    compute_euler_tour();
    // Bucketing nodes in preorder leaves each at_dep[d] sorted by in_tme,
    // which count_in_subtree relies on.
    for (int t = 1; t <= n; t++) at_dep[dep[flat[t]]].push_back(flat[t]);

    for (int col = 0; col < k; col++) {
        if (static_cast<int>(is_col[col].size()) >= SQRT)
            process_whole_tree(col);
        else
            process_small(col);
    }
    printf("%d %d\n", ans_tot, ans_moves);
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...