#include<bits/stdc++.h>
using namespace std;
// Capacity of every per-vertex array (value 200005).
constexpr int N = 200'005;

int n;   // number of vertices, read in main
int D;   // distance threshold, read in main
int it;  // running preorder counter advanced by init()

int in[N];   // in[u]: preorder index assigned when init() enters u
int out[N];  // out[u]: largest preorder index handed out inside u's subtree
int h[N];    // h[u]: depth of u (the root passed to init() keeps depth 0)

// up[k][u]: ancestor of u reached by 2^k parent steps; the root maps to itself.
int up[20][N];

// ad[u]: adjacency list of vertex u (edges are stored in both directions).
vector<int> ad[N];
// Depth-first traversal that fills the preorder interval [in[u], out[u]],
// the depth h[] and the binary-lifting table up[][] for u's subtree.
//
// u: vertex being entered.
// p: vertex we arrived from (skipped when walking ad[u]).
//
// up[0][u] and h[u] must already be set by the caller for every u except
// vertex 1, which is treated as the root and made its own ancestor at all levels.
void init(int u, int p) {
    in[u] = ++it;

    if (u == 1) {
        for (int k = 0; k <= 18; ++k) {
            up[k][u] = u;
        }
    } else {
        for (int k = 1; k <= 18; ++k) {
            const int halfway = up[k - 1][u];
            up[k][u] = up[k - 1][halfway];
        }
    }

    for (const int child : ad[u]) {
        if (child == p) {
            continue;
        }
        // Parent link and depth are written before descending so the child can
        // build the rest of its lifting row.
        up[0][child] = u;
        h[child] = h[u] + 1;
        init(child, u);
    }

    out[u] = it;
}
// Returns 1 when the preorder interval of u contains that of v (so u is an
// ancestor of v, or u == v), and 0 otherwise.
int anc(int u, int v) {
    const bool entered_first = in[u] <= in[v];
    const bool left_last = out[v] <= out[u];
    return entered_first && left_last;
}
// Lowest common ancestor of u and v using the up[][] lifting table.
//
// Starting from u, probes jumps of 2^18, ..., 2^0 steps: a probe that lands on
// an ancestor of v is remembered as the current answer, any other probe is
// actually taken. The last remembered probe is returned.
int lca(int u, int v) {
    if (anc(u, v)) {
        return u;
    }
    int best = u;
    for (int k = 18; k >= 0; --k) {
        const int target = up[k][u];
        if (anc(target, v)) {
            best = target;
        } else {
            u = target;
        }
    }
    return best;
}
// Number of edges on the tree path between u and v, computed from the depths
// of both endpoints and of their lowest common ancestor.
int dist(int u, int v) {
    const int meet = lca(u, v);
    const int u_to_meet = h[u] - h[meet];
    const int v_to_meet = h[v] - h[meet];
    return u_to_meet + v_to_meet;
}
// Exhaustive solver: tries every subset of the n vertices and keeps the
// largest one in which no two chosen vertices are closer than D.
// main only selects it for n <= 18, so a subset fits in an int bitmask.
namespace sub1 {
int ans = 0;  // size of the best valid subset found

// Enumerates all 2^n subsets and prints the maximum valid size (no newline).
void solve() {
    // conflict[i]: bitmask (bit j-1 <-> vertex j) of the vertices j != i with
    // dist(i, j) < D. Built once, so the per-subset check below needs no LCA
    // queries at all.
    vector<int> conflict(n + 1, 0);
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j < i; ++j) {
            if (dist(i, j) < D) {
                conflict[i] |= 1 << (j - 1);
                conflict[j] |= 1 << (i - 1);
            }
        }
    }

    for (int mask = 0; mask < (1 << n); ++mask) {
        bool ok = true;
        // Stop at the first chosen vertex that has a chosen conflicting partner.
        for (int i = 1; i <= n && ok; ++i) {
            const bool chosen = (mask >> (i - 1)) & 1;
            if (chosen && (mask & conflict[i]) != 0) {
                ok = false;
            }
        }
        if (ok) ans = max(ans, __builtin_popcount(mask));
    }
    cout << ans;
}
}
// Quadratic greedy solver: repeatedly picks the deepest remaining vertex and
// discards every vertex closer than D to it. main selects it for n <= 1500.
namespace sub2 {
int ans = 0;  // number of vertices picked so far

// Runs the greedy over vertices 1..n and prints the number of picks (no newline).
void solve() {
    vector<int> v;
    v.reserve(n);
    for (int i = 1; i <= n; ++i) v.push_back(i);

    while (!v.empty()) {
        ++ans;

        // Seed with a real candidate. Seeding with vertex 0 (which does not
        // exist) would leave mx == 0 whenever only depth-0 vertices remain and
        // make dist() run against a phantom vertex.
        int mx = v.front();
        for (int u : v)
            if (h[u] > h[mx]) mx = u;

        vector<int> nxt;
        for (int u : v) {
            // The picked vertex itself is always dropped, so the loop makes
            // progress even if D <= 0 (where dist(mx, mx) >= D would hold).
            if (u != mx && dist(u, mx) >= D) nxt.push_back(u);
        }
        v = std::move(nxt);
    }
    cout << ans;
}
}
// Take vertices greedily, preferring the ones at the greatest depth.
// Each subtree keeps a min-heap (by depth) of the vertices it currently keeps;
// child heaps are merged smaller-into-larger.
namespace sub3 {
using ii = pair<int, int>;  // (depth, vertex)

// q[u]: vertices currently kept inside u's subtree, shallowest on top.
priority_queue<ii, vector<ii>, greater<ii>> q[N];

// Processes u's subtree (p is the vertex we came from) and leaves the kept
// vertices of the whole subtree in q[u].
void dfs(int u, int p) {
    for (int v : ad[u]) {
        if (v == p) continue;
        dfs(v, u);

        // q[u] only holds vertices from previously merged children and q[v]
        // only vertices below v, so any path between the two heaps goes
        // through u: its length is depth_a + depth_b - 2 * h[u]. No LCA query
        // is needed. While the two shallowest entries are too close, discard
        // the shallower of them.
        while (!q[u].empty() && !q[v].empty() &&
               q[u].top().first + q[v].top().first - 2 * h[u] < D) {
            if (q[u].top().first < q[v].top().first) q[u].pop();
            else q[v].pop();
        }

        // Small-to-large: keep the bigger heap in q[u] and drain the other into it.
        if (q[u].size() < q[v].size()) swap(q[u], q[v]);
        while (!q[v].empty()) {
            q[u].push(q[v].top());
            q[v].pop();
        }
    }

    // Everything in q[u] lies below u, so the distance from u to the
    // shallowest kept vertex is simply the depth difference.
    if (q[u].empty() || q[u].top().first - h[u] >= D) q[u].emplace(h[u], u);
}

// Runs the merge from vertex 1 and prints how many vertices were kept.
void solve() {
    dfs(1, 0);
    cout << q[1].size() << '\n';
}
}
// Reads n and D, then one parent index per vertex 2..n (given 0-based and
// shifted to 1-based here), builds the tree, precomputes the LCA tables from
// vertex 1 and dispatches to the solver matching the size of n.
int32_t main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> D;
    for (int child = 2; child <= n; ++child) {
        int parent;
        cin >> parent;
        ++parent;  // input is 0-based, the arrays are 1-based
        ad[parent].push_back(child);
        ad[child].push_back(parent);
    }

    init(1, 0);

    if (n <= 18) {
        sub1::solve();      // exhaustive search over subsets
    } else if (n <= 1'500) {
        sub2::solve();      // quadratic greedy
    } else {
        sub3::solve();      // heap merging
    }
}
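// NOTE(review): the lines below appear to be results-table text pasted from
// the judge's submission page; they are kept only as a comment so that the
// file compiles.
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |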