제출 #1090843

# | 제출 시각 (Submitted) | 아이디 (User) | 문제 (Problem) | 언어 (Language) | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory)
1090843 | ShaShi | Cat in a tree (BOI17_catinatree) | C++17
100 / 100
221 ms | 72396 KiB
#include <bits/stdc++.h>
// Contest-template shorthands.
// NOTE: `int` is NOT redefined as long long (the next line is commented out),
// so `int` keeps its native width (typically 32-bit) everywhere in this file.
// #define int long long 
// Short names for std::pair members.
#define F first
#define S second
// Whole-container iterator range, e.g. sort(all(v)). The argument is not
// parenthesised, so pass a plain name.
#define all(x) x.begin(), x.end()
// Print x and terminate the program (not used in this file).
#define kill(x) cout << x << endl, exit(0);
#define pb push_back
#define mp make_pair
#define pii pair<int, int>
// Plain newline instead of std::endl, so output is not flushed on every line.
#define endl "\n"
 
 
using namespace std;
typedef long long ll;
typedef long double ld;
 
 
// Array bound for 1-based vertex ids (n <= 2e5) plus slack.
const int MAXN = (int)2e5 + 7;
// Template leftover; not referenced in this file.
const int MOD = 998244353;
// 1e18 does not fit in a 32-bit int (int is not redefined as long long here),
// and converting an out-of-range double to int is undefined behaviour, so the
// constant is kept in a 64-bit type.
const ll INF = (ll)1e18 + 7;
// Capacity of jad[v]: one entry per centroid ancestor of v. Centroid
// decomposition depth is at most about log2(MAXN) + 1 (~18), so 30 is ample.
const int LG = 30;
 
 
// Problem input: n = number of vertices, k = minimum distance required
// between any two chosen vertices (checked via get(v) >= k in main).
int n, k;
// u receives each input index in main; ans is written by get().
int u, ans;
// Template leftovers, not referenced in this file.
int m, t, flag, v, w, ans2, tmp, tmp2, tmp3, tmp4;

// Per-vertex arrays, indexed by 1-based vertex id.
int arr[MAXN];          // not referenced in this file
int sz[MAXN];           // subtree size inside the current centroid component
int res[MAXN];          // per centroid: smallest distance to a chosen vertex so far
int h[MAXN];            // depth from vertex 1 (filled by DFS)
bool hate[MAXN];        // set once the vertex has been taken as a centroid
vector<int> adj[MAXN];  // adjacency lists
vector<int> fin;        // vertices chosen by the greedy in main
pii jad[MAXN][LG];      // (centroid ancestor, distance) pairs from upd_dist
int pnt[MAXN];          // number of valid entries in jad[v]
vector<pii> vec;        // (depth, vertex) pairs collected by DFS
 
 
// Fill sz[] with subtree sizes for the part of the tree reachable from v
// without crossing removed (hate) vertices, rooted at v; p is the vertex we
// came from (-1 at the root).
void DFSsz(int v, int p=-1) {
    int total = 1;

    for (int i = 0; i < (int)adj[v].size(); ++i) {
        const int child = adj[v][i];
        if (child != p && !hate[child]) {
            DFSsz(child, v);
            total += sz[child];
        }
    }

    sz[v] = total;
}
 
 
// Starting at v (with sz[] computed from the component root), repeatedly step
// into the first live child whose subtree holds more than half of the tot
// vertices; the vertex where no such child exists is returned as the centroid.
inline int centroid(int tot, int v, int p=-1) {
    for (;;) {
        int heavy = -1;
        for (int nb : adj[v]) {
            if (nb != p && !hate[nb] && 2*sz[nb] > tot) {
                heavy = nb;
                break;
            }
        }
        if (heavy < 0) return v;
        p = v;
        v = heavy;
    }
}
 
 
// For every live vertex in the subtree entered at v from p, append the pair
// (cent, distance to cent) to its jad[] list; dis is v's own distance.
void upd_dist(int v, int cent, int p=-1, int dis=1) {
    for (const int nxt : adj[v]) {
        if (nxt == p || hate[nxt]) continue;
        upd_dist(nxt, cent, v, dis + 1);
    }

    int &slot = pnt[v];
    jad[v][slot] = {cent, dis};
    ++slot;
}
 
 
// Centroid decomposition of the live component containing v: find its
// centroid, mark it removed, record its distance to every vertex of each
// remaining neighbouring piece, then recurse into those pieces.
void solve(int v) {
    DFSsz(v);
    const int c = centroid(sz[v], v);

    hate[c] = 1;

    const auto &nb = adj[c];
    // Distances must be recorded for all pieces before any recursion,
    // because solve() marks further vertices as removed.
    for (size_t i = 0; i < nb.size(); ++i)
        if (!hate[nb[i]]) upd_dist(nb[i], c, c);
    for (size_t i = 0; i < nb.size(); ++i)
        if (!hate[nb[i]]) solve(nb[i]);
}
 
 
// Mark v as chosen: every centroid ancestor c of v (at distance d) lowers
// res[c] to d if that is closer; v itself is a centroid at distance 0.
inline void add(int v) {
    for (int i = 0; i < pnt[v]; ++i) {
        const int c = jad[v][i].first;
        const int d = jad[v][i].second;
        if (d < res[c]) res[c] = d;
    }
    res[v] = 0;
}
 
 
// Smallest value of res[v] and (distance to c) + res[c] over v's centroid
// ancestors c, i.e. the distance from v to the nearest chosen vertex, or a
// value at least the res[] sentinel when nothing has been chosen. The result
// is also stored in the global ans.
inline int get(int v) {
    int best = res[v];
    for (int i = 0; i < pnt[v]; ++i) {
        const int cand = jad[v][i].second + res[jad[v][i].first];
        best = cand < best ? cand : best;
    }
    ans = best;
    return ans;
}
 
 
// Preorder traversal from v: append (h[v], v) to vec, and set the depth of
// each child to one more than v's before descending (h[root] is left as-is).
void DFS(int v, int p=-1) {
    vec.emplace_back(h[v], v);

    for (const int child : adj[v]) {
        if (child != p) {
            h[child] = h[v] + 1;
            DFS(child, v);
        }
    }
}
 
 
int32_t main() {
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
 
    cin >> n >> k;
 
    for (int i=2; i<=n; i++) {
        cin >> u; u++;
 
        adj[u].pb(i); adj[i].pb(u);
    }
 
    solve(n); DFS(1); fill(res, res+n+1, n+1);
    sort(all(vec), greater<pii>());
 
    for (int i=0; i<n; i++) if (get(vec[i].S) >= k) add(vec[i].S), fin.pb(vec[i].S);
 
    cout << fin.size() << endl;
    // for (int u:fin) cout << u << " ";
    // cout << endl;
 
 
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...