제출 #1302081

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1302081 | — | nguyn | Cat in a tree (BOI17_catinatree) | C++20 | 100 / 100 | 265 ms | 41520 KiB
#include <bits/stdc++.h>
using namespace std;

#ifdef LOCAL
#include "debug.h"
#else  
#define debug(...)
#endif 

// Competitive-programming shorthands used throughout the file.
#define ll long long
#define F first
#define S second
#define pb push_back
#define pii pair<int, int>

// Array capacity. Vertices are 1-indexed (main adds 1 to every input id), so 0 is free to mean "none".
const int N = 2e5 + 5;

int n, k;          // n: number of vertices; k: chosen vertices must be at tree distance >= k from each other
vector<int> g[N];  // adjacency lists of the tree
vector<pii> cur;   // (depth, vertex) for every vertex; filled by dfs, sorted deepest-first in main
vector<int> ans;   // vertices accepted by the greedy in main
int mx[N];         // per centroid c: max over accepted x with c as a centroid ancestor of (k - dist(c, x)); stays >= 0
int h[N];          // depth of each vertex below the dfs root (main roots at vertex 1)
int vis[N];        // NOTE(review): not referenced anywhere in the visible code
int del[N];        // set to 1 once a vertex has been chosen as a centroid in decom
int par[N];        // parent in the centroid tree (0 for the top-level centroid)
int up[N][20];     // binary lifting: up[v][i] is the 2^i-th ancestor of v, 0 once past the root

// Pre-order traversal of the tree from u, whose parent is p (0 for the root).
// For every vertex x reached, appends (h[x], x) to `cur`. For every non-root
// vertex v, sets h[v] = h[parent] + 1 and fills the binary-lifting row up[v][0..19].
// h[u] and up[u][*] are not touched; they must already hold the right values.
// For the root, the global zero-initialisation provides them.
//
// Uses an explicit stack rather than recursion. Recursion depth would equal
// the tree height, which is up to n on a path-shaped tree and can overflow the
// default call stack. Children are pushed in reverse adjacency order so they
// are popped in the original order. This reproduces the recursive pre-order
// exactly, including the order of entries appended to `cur`.
void dfs(int u, int p) {
    std::vector<std::pair<int, int>> stk;  // (vertex, its parent)
    stk.push_back({u, p});
    while (!stk.empty()) {
        const auto [x, px] = stk.back();
        stk.pop_back();
        cur.push_back({h[x], x});
        for (auto it = g[x].rbegin(); it != g[x].rend(); ++it) {
            const int v = *it;
            if (v == px) continue;
            h[v] = h[x] + 1;
            up[v][0] = x;
            // up[x][*] is already final because x was reached before v.
            for (int i = 1; i < 20; i++) {
                up[v][i] = up[up[v][i - 1]][i - 1];
            }
            stk.push_back({v, x});
        }
    }
}

int lca(int u, int v) {
    if (h[u] != h[v]) {
        if (h[u] < h[v]) swap(u, v);
        int k = h[u] - h[v];
        for (int i = 0; (1 << i) <= k; i++) {
            if (k >> i & 1) {
                u = up[u][i]; 
            }
        }
    }
    if (u == v) {
        return u; 
    }
    for (int i = 19; i >= 0; i--) {
        if (up[v][i] != up[u][i]) {
            u = up[u][i];
            v = up[v][i]; 
        }
    }
    return up[u][0]; 
}

// Number of edges on the tree path between u and v, measured through their LCA.
int dist(int u, int v) {
    const int anc = lca(u, v);
    return (h[u] - h[anc]) + (h[v] - h[anc]);
}

int sz[N];  // subtree sizes inside the current (non-deleted) component, rooted where count_child started

// Computes sz[x] for every vertex x reachable from u without passing through
// p or any vertex with del[x] set, treating u as the root. Returns sz[u], the
// size of that component.
//
// Iterative so that a long path (depth up to n) cannot overflow the call
// stack. The BFS below lists every vertex after its parent. Walking that list
// backwards therefore finishes each subtree before its size is added to the
// parent's.
int count_child(int u, int p) {
    std::vector<std::pair<int, int>> order;  // (vertex, parent) in BFS order
    order.push_back({u, p});
    for (std::size_t i = 0; i < order.size(); i++) {
        // Copy the fields out, since push_back below may reallocate `order`.
        const int x = order[i].first;
        const int px = order[i].second;
        for (int v : g[x]) {
            if (v == px || del[v]) continue;
            order.push_back({v, x});
        }
    }
    for (const auto &entry : order) sz[entry.first] = 1;
    // Index 0 is u itself. Its size must not be pushed into p.
    for (std::size_t i = order.size() - 1; i > 0; --i) {
        sz[order[i].second] += sz[order[i].first];
    }
    return sz[u];
}

// Finds a centroid of the component of size `siz` whose sizes were just computed
// by count_child, starting from its root u (with parent p in that rooting).
// Repeatedly steps into the first non-deleted child whose subtree holds more
// than siz / 2 vertices. It stops at the first vertex that has no such child and
// returns that vertex.
//
// Written as a loop rather than self-recursion. The walk can be about n / 2
// steps long on a path, and tail-call elimination is not guaranteed.
int find_centroid(int u, int p, int siz) {
    while (true) {
        int heavy = -1;  // first child whose subtree exceeds half the component
        for (int v : g[u]) {
            if (v == p || del[v]) continue;
            if (sz[v] > siz / 2) {
                heavy = v;
                break;
            }
        }
        if (heavy < 0) return u;
        p = u;
        u = heavy;
    }
}

int decom(int u) {
    int siz = count_child(u, 0);
    int root = find_centroid(u, 0, siz); 
    del[root] = 1;

    for (int v : g[root]) {
        if (del[v]) continue;
        int cen = decom(v); 
        par[cen] = root; 
    }
    return root;
}

signed main() {
	ios_base::sync_with_stdio(0);
	cin.tie(0); 
	cin >> n >> k;
	for (int i = 2; i <= n; i++) {
        int u, v;
        v = i;
        cin >> u;
        u++;
        g[u].pb(v);
        g[v].pb(u); 
    }
    dfs(1, 0); 
    sort(cur.begin(), cur.end(), greater<pii>());
    decom(1); 
    for (auto [d, u] : cur) {
        int v = u;
        bool ok = 1;
        while(v != 0) { 
            int cur = dist(v, u); 
            if (cur < mx[v]) {
                ok = 0;
                break; 
            } 
            v = par[v]; 
        }
        if (!ok) continue;
        v = u; 
        // debug(v); 
        ans.pb(v); 
        while(v != 0) {
            int cur = dist(v, u); 
            mx[v] = max(mx[v], k - cur); 
            v = par[v]; 
        }
    }  
    cout << ans.size() << '\n';
    // for (int u : ans) {
    //     cout << u << ' ';
    // }
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...