이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include <iostream>
#include <algorithm>
#include <functional>
#include <random>
#include <cmath>
#include <vector>
#include <array>
#include <set>
#include <map>
#include <queue>
#include <cassert>
#include <string>
#include <bitset>
#include <numeric>
#include <iomanip>
#include <limits.h>
#include <tuple>
using namespace std;

// Short-hand loop helpers: rep walks l, l+1, ..., r-1; per walks r, r-1, ..., l.
#define rep(i, l, r) for (int i = (l); i < (r); ++i)
#define per(i, r, l) for (int i = (r); i >= (l); --i)
#define sz(x) (int)size(x)
#define pb push_back
#define all(x) begin(x), end(x)
#define ff first
#define ss second

using ll = long long;
using pp = pair<int, int>;

constexpr ll mod = 1e9 + 7;      // unused by the visible code
constexpr ll maxn = 2e5 + 5;     // capacity of every per-node array below
constexpr ll inf = ll(1e9) + 5;  // "no marked node yet" distance sentinel
constexpr ll lg = 20;            // number of binary-lifting levels in nxt

vector<int> adj[maxn];  // undirected adjacency lists of the tree
int cnt[maxn];          // subtree sizes inside the component currently being decomposed
int h[maxn];            // depth of each node when the tree is rooted at 0
int par[maxn];          // centroid-tree parent: -2 = not yet decomposed, -1 = top centroid
int min_dist[maxn];     // per centroid: smallest distance to a node passed to upd
int nxt[maxn][lg];      // nxt[v][i] = 2^i-th ancestor of v (binary lifting)
void dfslca(int r, int p = 0){
nxt[r][0] = p;
rep(i,1,lg) nxt[r][i] = nxt[nxt[r][i-1]][i-1];
for(int c: adj[r]) if(c - p) h[c] = h[r] + 1, dfslca(c, r);
}
// Lowest common ancestor of u and v, using the depths h and the binary-lifting
// table nxt filled by dfslca.
int lca(int u, int v){
    if(h[v] < h[u]) swap(u, v);  // afterwards v is the deeper (or equally deep) endpoint
    // Lift v to u's depth: take every jump that does not go above u's level.
    for(int k = lg - 1; k >= 0; k--)
        if(h[nxt[v][k]] >= h[u]) v = nxt[v][k];
    if(v == u) return u;  // u was an ancestor of v
    // Climb both endpoints while their ancestors still differ; they end just below the LCA.
    for(int k = lg - 1; k >= 0; k--){
        if(nxt[u][k] == nxt[v][k]) continue;
        u = nxt[u][k];
        v = nxt[v][k];
    }
    return nxt[u][0];
}
// Number of edges on the tree path between u and v (depth of each endpoint below their LCA).
int dist(int u, int v){
    const int w = lca(u, v);
    return (h[u] - h[w]) + (h[v] - h[w]);
}
void dfs(int r, int p = -1){
cnt[r] = 1;
for(int c: adj[r]) if(par[c] == -2 && c - p) dfs(c, r), cnt[r] += cnt[c];
}
// Walks down from r: while some undecomposed neighbour (other than the node we came from)
// has more than half of cnt[bs] nodes in its subtree, step into the first such neighbour.
// Returns the node where no such step exists. With cnt filled by dfs(bs) and r == bs,
// this is a centroid of bs's component.
int find_cnt(int r, int bs, int p = -1){
    for(;;){
        int heavy = -1;
        for(int c: adj[r]){
            if(c == p || par[c] != -2) continue;
            if(2 * cnt[c] > cnt[bs]){
                heavy = c;
                break;
            }
        }
        if(heavy == -1) return r;
        p = r;
        r = heavy;
    }
}
// Centroid-decomposes the still-undecomposed component containing r.
// The component's centroid gets par = p, which both records its centroid-tree parent
// (-1 for the top-level call) and removes it from later dfs/find_cnt walks. Every
// neighbouring piece that is still undecomposed is then handled with the centroid as parent.
void dec(int r, int p = -1){
    dfs(r);
    const int centre = find_cnt(r, r);
    par[centre] = p;
    for(int nb: adj[centre]){
        if(par[nb] != -2) continue;  // already removed from the decomposition
        dec(nb, centre);
    }
}
// Returns the minimum of min_dist[c] + dist(c, u) over u and all its centroid-tree
// ancestors c (following par until -1). By the centroid-tree path property this equals
// the distance from u to the nearest node passed to upd. If no node has been marked and
// min_dist is still inf everywhere (as main initialises it), the result is inf.
int get(int u){
    int best = inf;
    int c = u;
    while(c != -1){
        const int via = min_dist[c] + dist(c, u);
        if(via < best) best = via;
        c = par[c];
    }
    return best;
}
// Marks u: for u and every centroid-tree ancestor c (following par until -1),
// lowers min_dist[c] to dist(c, u) when that is smaller.
void upd(int u){
    for(int c = u; c != -1; c = par[c]){
        const int d_cu = dist(c, u);
        if(d_cu < min_dist[c]) min_dist[c] = d_cu;
    }
}
int main(){
cin.tie(0) -> sync_with_stdio(0);
int n, d; cin >> n >> d;
rep(i,1,n){
int p; cin >> p;
adj[p].pb(i), adj[i].pb(p);
}
fill(par, par + n, -2);
fill(min_dist, min_dist + n, inf);
dec(0);
dfslca(0);
vector<int> node(n); iota(all(node), 0), sort(all(node), [&](int u, int v){ return h[u] > h[v]; });
int ans = 0;
for(int u: node){
if(get(u) >= d) upd(u), ans++;
}
cout << ans << '\n';
}
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |