제출 #1339382

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1339382 | | f0rizen | Cat in a tree (BOI17_catinatree) | C++20 | 100 / 100 | 273 ms | 49092 KiB
#include <bits/stdc++.h>

using namespace std;
// Shorthand for 64-bit signed integers.
using ll = long long;
// "Infinity" sentinels: inf = 10^9 + 7 for int values, infll = 10^18 for ll values.
constexpr int inf = 1'000'000'007;
constexpr ll infll = 1'000'000'000'000'000'000LL;

// Reads every element of an already-sized vector from the stream, in index order.
// Returns the stream so reads can be chained.
template<typename T>
istream &operator>>(istream &is, vector<T> &a) {
    for (size_t k = 0; k < a.size(); ++k) {
        is >> a[k];
    }
    return is;
}

// Adjacency lists of the input tree (filled in main).
std::vector<std::vector<int>> g;
// used[v] becomes true once build() has taken v as a centroid.
std::vector<bool> used;
// Subtree sizes inside the current undecomposed component (written by dfs_sz).
std::vector<int> sz;
// Parent of each vertex in the centroid tree; -1 for the top-level centroid.
std::vector<int> par;
// dist[v]: distances from v to its centroid ancestors, outermost first (appended by dfs_dist).
std::vector<std::vector<int>> dist;
// (depth, vertex) pairs for the whole tree, depth measured from dfs()'s start vertex.
std::vector<std::pair<int, int>> D;

// Computes sz[x] for every vertex x reachable from v without crossing p or
// a vertex already marked used: the size of x's subtree when rooted at v.
void dfs_sz(int v, int p = -1) {
    int total = 1;
    for (int to : g[v]) {
        if (to == p || used[to]) {
            continue;
        }
        dfs_sz(to, v);
        total += sz[to];
    }
    sz[v] = total;
}

// Appends to dist[x] the distance from the start vertex to x, for every x in
// the component reachable from the start without crossing used vertices.
// p is the vertex we came from; d is the distance of v from the start.
void dfs_dist(int v, int p = -1, int d = 0) {
    dist[v].push_back(d);
    const int next_depth = d + 1;
    for (int to : g[v]) {
        if (to == p || used[to]) {
            continue;
        }
        dfs_dist(to, v, next_depth);
    }
}

// Finds a centroid of a component of n vertices, starting at v (entered
// from p). Repeatedly steps into the first unused neighbour whose sz exceeds
// n / 2, and returns the vertex where no such neighbour exists.
// sz[] must come from a dfs_sz whose root is v or lies on p's side, so that
// sz[u] counts the part of the component hanging below u.
int centroid(int v, int p, int n) {
    for (;;) {
        int heavy = -1;
        for (int u : g[v]) {
            if (u != p && !used[u] && sz[u] * 2 > n) {
                heavy = u;
                break;
            }
        }
        if (heavy == -1) {
            return v;
        }
        p = v;
        v = heavy;
    }
}

void build(int v, int p = -1) {
    dfs_sz(v);
    par[v] = p;
    dfs_dist(v);
    used[v] = 1;
    for (auto u : g[v]) {
        if (!used[u]) {
            build(centroid(u, v, sz[u]), v);
        }
    }
}

// Appends (depth, vertex) to D for every vertex of the whole tree, with depth
// measured from the starting vertex. Ignores used[] deliberately, so the
// full tree is walked.
void dfs(int v, int p = -1, int d = 0) {
    D.emplace_back(d, v);
    for (int to : g[v]) {
        if (to != p) dfs(to, v, d + 1);
    }
}

// Distance from v to the nearest already-chosen vertex, or inf if nothing has
// been chosen near any of v's centroid ancestors. Walks the centroid-tree
// ancestors of v starting with v itself; dist[v] lists the distances to those
// ancestors outermost-first, so it is read back to front in step with the walk.
static int nearest_chosen(int v, const vector<int> &closest) {
    int best = inf;
    int j = (int) dist[v].size() - 1;
    for (int u = v; u != -1; u = par[u], --j) {
        // Skip ancestors whose component has no chosen vertex yet.
        if (closest[u] < inf) {
            best = min(best, dist[v][j] + closest[u]);
        }
    }
    return best;
}

// Marks v as chosen: every centroid ancestor u of v (v included) lowers
// closest[u] to the distance between u and v.
static void mark_chosen(int v, vector<int> &closest) {
    int j = (int) dist[v].size() - 1;
    for (int u = v; u != -1; u = par[u], --j) {
        closest[u] = min(closest[u], dist[v][j]);
    }
}

// Reads n and d, then for each vertex i = 1..n-1 its neighbour p. Prints how
// many vertices the deepest-first greedy keeps, where a vertex is kept only if
// every previously kept vertex is at distance >= d from it.
int32_t main() {
#ifdef LOCAL
    freopen("/tmp/input.txt", "r", stdin);
#else
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
#endif
    int n, d;
    cin >> n >> d;
    g.resize(n);
    for (int i = 1; i < n; ++i) {
        int p;
        cin >> p;
        g[p].push_back(i);
        g[i].push_back(p);
    }
    used.resize(n);
    sz.resize(n);
    par.resize(n);
    dist.resize(n);

    // Centroid decomposition, starting from the centroid of the whole tree.
    dfs_sz(0);
    const int c = centroid(0, -1, sz[0]);
    build(c);

    // Order vertices by decreasing depth from c (ties: larger index first).
    dfs(c);
    sort(D.rbegin(), D.rend());

    // closest[u]: distance from centroid u to the nearest chosen vertex in u's component.
    vector<int> closest(n, inf);
    int ans = 0;
    for (const auto &entry : D) {
        const int v = entry.second;
        if (nearest_chosen(v, closest) >= d) {
            mark_chosen(v, closest);
            ++ans;
        }
    }
    cout << ans << "\n";
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...