제출 #591523

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
591523 | MohammadAghil | Cat in a tree (BOI17_catinatree) | C++17
100 / 100
447 ms | 88588 KiB
// BOI 2017 "Cat in a Tree" (BOI17_catinatree).
//
// Input : N and D, then for every node i in [1, N) the index of its parent
//         (node 0 is the root).
// Output: the number of nodes picked by the following greedy. Visit nodes in
//         order of decreasing depth. Pick a node when its distance to the
//         nearest already-picked node is at least D.
//
// Nearest-picked-node distances are answered with a centroid decomposition.
// Tree distances come from an O(1) LCA: a sparse table over a DFS preorder.
// Every traversal is iterative, so a path-shaped tree with N = 2e5 cannot
// overflow the call stack. The previous recursive DFS went one frame per node
// along a path.
#include <algorithm>
#include <array>
#include <bitset>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits.h>
#include <map>
#include <numeric>
#include <queue>
#include <random>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

namespace {

using Graph = std::vector<std::vector<int>>;

// "No picked node yet" sentinel. It is larger than any tree distance, and
// kInf + distance still fits in an int for any realistic N.
constexpr int kInf = 1000000000 + 5;

// The tree rooted at node 0.
struct RootedTree {
  std::vector<int> order;   // DFS preorder; every subtree is a contiguous range
  std::vector<int> parent;  // parent[0] == -1
  std::vector<int> depth;   // depth[0] == 0
};

// Roots the (non-empty) tree at node 0 using an explicit stack.
// Popping a node and then pushing its children yields a valid DFS preorder.
RootedTree root_at_zero(const Graph& adj) {
  const int n = static_cast<int>(adj.size());
  RootedTree tree;
  tree.order.reserve(n);
  tree.parent.assign(n, -1);
  tree.depth.assign(n, 0);

  std::vector<int> pending{0};
  while (!pending.empty()) {
    const int v = pending.back();
    pending.pop_back();
    tree.order.push_back(v);
    for (const int c : adj[v]) {
      if (c == tree.parent[v]) continue;
      tree.parent[c] = v;
      tree.depth[c] = tree.depth[v] + 1;
      pending.push_back(c);
    }
  }
  return tree;
}

// O(1) LCA / distance queries after an O(n log n) build.
//
// For tin[u] < tin[v], lca(u, v) is found as follows. Look at the nodes at
// preorder positions (tin[u], tin[v]]. Among their parents, take the one with
// the smallest preorder index. The sparse table stores, at position i, the
// parent of order[i]. It keeps the preorder-earliest node of each range.
class TreeDistance {
 public:
  // `tree` must be non-empty.
  explicit TreeDistance(const RootedTree& tree)
      : depth_(tree.depth), tin_(tree.order.size()) {
    const int n = static_cast<int>(tree.order.size());
    for (int i = 0; i < n; ++i) tin_[tree.order[i]] = i;

    log2_.assign(n + 1, 0);
    for (int i = 2; i <= n; ++i) log2_[i] = log2_[i / 2] + 1;

    const int levels = log2_[n] + 1;
    table_.assign(levels, std::vector<int>(n));
    // Position 0 is never part of a query, because ranges start at tin + 1.
    // It still has to hold a valid node, since earlier() indexes tin_ with it.
    table_[0][0] = tree.order[0];
    for (int i = 1; i < n; ++i) table_[0][i] = tree.parent[tree.order[i]];
    for (int j = 1; j < levels; ++j) {
      const int half = 1 << (j - 1);
      for (int i = 0; i + 2 * half <= n; ++i) {
        table_[j][i] = earlier(table_[j - 1][i], table_[j - 1][i + half]);
      }
    }
  }

  int lca(int u, int v) const {
    if (u == v) return u;
    int lo = tin_[u];
    int hi = tin_[v];
    if (lo > hi) std::swap(lo, hi);
    ++lo;  // query the range (tin[u], tin[v]]
    const int k = log2_[hi - lo + 1];
    return earlier(table_[k][lo], table_[k][hi - (1 << k) + 1]);
  }

  int distance(int u, int v) const {
    return depth_[u] + depth_[v] - 2 * depth_[lca(u, v)];
  }

 private:
  // Returns whichever of a and b comes first in the preorder.
  int earlier(int a, int b) const { return tin_[a] < tin_[b] ? a : b; }

  std::vector<int> depth_;
  std::vector<int> tin_;
  std::vector<int> log2_;
  std::vector<std::vector<int>> table_;
};

// Builds a centroid decomposition of the (non-empty) tree. Returns each
// node's parent in the centroid tree, with -1 for the top-level centroid.
// Components are taken from an explicit work list instead of recursion.
std::vector<int> centroid_parents(const Graph& adj) {
  const int n = static_cast<int>(adj.size());
  std::vector<int> cpar(n, -1);
  std::vector<char> removed(n, 0);
  std::vector<int> sub(n, 0);            // subtree size inside the current component
  std::vector<int> local_parent(n, -1);  // parent with the component rooted at `start`
  std::vector<int> comp;                 // current component, parents before children
  comp.reserve(n);

  // Each entry: (any node of an undecomposed component, centroid above it).
  std::vector<std::pair<int, int>> work{{0, -1}};
  while (!work.empty()) {
    const auto [start, above] = work.back();
    work.pop_back();

    // Collect the component by BFS over nodes not yet removed.
    comp.clear();
    comp.push_back(start);
    local_parent[start] = -1;
    for (std::size_t i = 0; i < comp.size(); ++i) {
      const int x = comp[i];
      for (const int y : adj[x]) {
        if (removed[y] || y == local_parent[x]) continue;
        local_parent[y] = x;
        comp.push_back(y);
      }
    }

    // Subtree sizes, accumulated from children to parents.
    for (const int x : comp) sub[x] = 1;
    for (std::size_t i = comp.size() - 1; i > 0; --i) {
      sub[local_parent[comp[i]]] += sub[comp[i]];
    }

    // Walk from `start` toward the child holding more than half the
    // component. At most one such child exists. Stop when none does.
    const int total = static_cast<int>(comp.size());
    int centroid = start;
    for (bool moved = true; moved;) {
      moved = false;
      for (const int y : adj[centroid]) {
        if (!removed[y] && y != local_parent[centroid] && 2 * sub[y] > total) {
          centroid = y;
          moved = true;
          break;
        }
      }
    }

    cpar[centroid] = above;
    removed[centroid] = 1;
    // Every remaining neighbour lies in its own component.
    for (const int y : adj[centroid]) {
      if (!removed[y]) work.emplace_back(y, centroid);
    }
  }
  return cpar;
}

// Runs the depth-ordered greedy and returns how many nodes it picks.
// parent_of[i] is the parent of node i for i in [1, n); parent_of[0] is unused.
int solve(int n, int d, const std::vector<int>& parent_of) {
  if (n <= 0) return 0;

  Graph adj(n);
  for (int i = 1; i < n; ++i) {
    adj[parent_of[i]].push_back(i);
    adj[i].push_back(parent_of[i]);
  }

  const RootedTree tree = root_at_zero(adj);
  const TreeDistance dist(tree);
  const std::vector<int> cpar = centroid_parents(adj);

  std::vector<int> nodes(n);
  std::iota(nodes.begin(), nodes.end(), 0);
  std::sort(nodes.begin(), nodes.end(),
            [&](int a, int b) { return tree.depth[a] > tree.depth[b]; });

  // best[c] holds the smallest distance from centroid c to a picked node
  // that has c among its centroid ancestors (itself included).
  std::vector<int> best(n, kInf);
  int picked = 0;
  for (const int u : nodes) {
    int nearest = kInf;
    for (int c = u; c != -1; c = cpar[c]) {
      nearest = std::min(nearest, best[c] + dist.distance(c, u));
    }
    if (nearest < d) continue;
    for (int c = u; c != -1; c = cpar[c]) {
      best[c] = std::min(best[c], dist.distance(c, u));
    }
    ++picked;
  }
  return picked;
}

}  // namespace

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  int n = 0;
  int d = 0;
  if (!(std::cin >> n >> d)) return 0;

  std::vector<int> parent_of(std::max(n, 0), 0);
  for (int i = 1; i < n; ++i) std::cin >> parent_of[i];

  std::cout << solve(n, d, parent_of) << '\n';
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...