제출 #828403

| # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
|---|---|---|---|---|---|---|---|
| 828403 | | serifefedartar | Cat in a tree (BOI17_catinatree) | C++17 | 100 / 100 | 363 ms | 39928 KiB |
#include <algorithm>
#include <climits>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

namespace {

// Sentinel meaning "no kept node recorded yet". It is compared against but
// never added to, so it cannot overflow and places no upper bound on D.
constexpr int kInf = INT_MAX;

using Graph = std::vector<std::vector<int>>;

/// Tree rooted at node 0 with binary-lifting ancestor tables for O(log N) LCA.
/// Built by BFS rather than recursive DFS, so a path-shaped tree (depth N - 1)
/// cannot overflow the call stack.
class RootedTree {
 public:
  /// @param graph undirected adjacency lists of a tree on nodes [0, size);
  ///              must contain at least one node.
  explicit RootedTree(const Graph& graph) : depth_(graph.size(), 0) {
    const int n = static_cast<int>(graph.size());
    // Smallest level count with 2^levels > n - 1 >= any depth; computed by
    // shifting (n - 1) so no `1 << levels` overflow is possible.
    int levels = 1;
    while (((n - 1) >> levels) != 0) ++levels;
    // Every ancestor entry of the root (node 0) stays 0, i.e. the root itself.
    up_.assign(levels, std::vector<int>(n, 0));

    std::vector<char> seen(n, 0);
    std::vector<int> queue;
    queue.reserve(n);
    queue.push_back(0);
    seen[0] = 1;
    for (std::size_t head = 0; head < queue.size(); ++head) {
      const int v = queue[head];
      for (const int u : graph[v]) {
        if (seen[u]) continue;
        seen[u] = 1;
        depth_[u] = depth_[v] + 1;
        up_[0][u] = v;
        // All ancestors of u were discovered earlier, so their tables are final.
        for (int i = 1; i < levels; ++i) up_[i][u] = up_[i - 1][up_[i - 1][u]];
        queue.push_back(u);
      }
    }
  }

  /// Edge count from the root to `v`.
  int depth(int v) const { return depth_[v]; }

  /// Ancestor of `v` exactly `k` levels up; requires 0 <= k <= depth(v).
  int lift(int v, int k) const {
    for (int i = 0; k != 0; ++i, k >>= 1) {
      if (k & 1) v = up_[i][v];
    }
    return v;
  }

  /// Lowest common ancestor of `a` and `b`.
  int lca(int a, int b) const {
    if (depth_[a] < depth_[b]) std::swap(a, b);
    a = lift(a, depth_[a] - depth_[b]);
    if (a == b) return a;
    for (int i = static_cast<int>(up_.size()) - 1; i >= 0; --i) {
      if (up_[i][a] != up_[i][b]) {
        a = up_[i][a];
        b = up_[i][b];
      }
    }
    return up_[0][a];
  }

  /// Number of edges on the tree path between `a` and `b`.
  int distance(int a, int b) const {
    return depth_[a] + depth_[b] - 2 * depth_[lca(a, b)];
  }

 private:
  std::vector<int> depth_;
  std::vector<std::vector<int>> up_;  // up_[i][v] = 2^i-th ancestor of v
};

/// Centroid decomposition of a tree.
/// @return parent of every node in the centroid tree (-1 for the top centroid).
/// Uses an explicit work stack plus BFS per component instead of recursion.
std::vector<int> centroid_parents(const Graph& graph) {
  const int n = static_cast<int>(graph.size());
  std::vector<int> cparent(n, -1);
  std::vector<char> removed(n, 0);
  std::vector<int> subtree(n, 0);
  std::vector<int> comp_parent(n, -1);
  std::vector<int> comp;  // current component, every node after its parent

  // Pending components: (some node inside it, centroid that owns it).
  std::vector<std::pair<int, int>> pending{{0, -1}};
  while (!pending.empty()) {
    const auto [start, owner] = pending.back();
    pending.pop_back();

    // Collect the component reachable from `start` without crossing removed
    // (already chosen) centroids, rooted at `start`.
    comp.clear();
    comp.push_back(start);
    comp_parent[start] = -1;
    for (std::size_t i = 0; i < comp.size(); ++i) {
      const int v = comp[i];
      for (const int u : graph[v]) {
        if (u == comp_parent[v] || removed[u]) continue;
        comp_parent[u] = v;
        comp.push_back(u);
      }
    }

    // Subtree sizes: children appear after parents, so accumulate in reverse.
    for (const int v : comp) subtree[v] = 1;
    for (std::size_t i = comp.size(); i-- > 1;) {
      const int v = comp[i];
      subtree[comp_parent[v]] += subtree[v];
    }

    // Walk from `start` toward the (unique) child holding more than half of
    // the component until none does; that node is the centroid.
    const int total = static_cast<int>(comp.size());
    int centroid = start;
    for (bool moved = true; moved;) {
      moved = false;
      for (const int u : graph[centroid]) {
        if (u == comp_parent[centroid] || removed[u]) continue;
        if (subtree[u] * 2 > total) {
          centroid = u;
          moved = true;
          break;
        }
      }
    }

    cparent[centroid] = owner;
    removed[centroid] = 1;
    for (const int u : graph[centroid]) {
      if (!removed[u]) pending.emplace_back(u, centroid);
    }
  }
  return cparent;
}

/// Largest set of nodes with pairwise distance >= d, found greedily.
/// Nodes are scanned deepest-first; a node is kept when every previously kept
/// node is at distance >= d. nearest[c] holds the minimum distance from
/// centroid c to a kept node inside c's centroid subtree. The tree path
/// between two nodes passes through their lowest common centroid ancestor, so
/// the minimum of dist(v, c) + nearest[c] over v's centroid ancestors c is
/// exactly v's distance to the closest kept node.
int max_cats(const Graph& graph, int d) {
  const int n = static_cast<int>(graph.size());
  if (n == 0) return 0;
  const RootedTree tree(graph);
  const std::vector<int> cparent = centroid_parents(graph);

  // Deepest first; ties broken by larger node index.
  std::vector<std::pair<int, int>> order;
  order.reserve(n);
  for (int v = 0; v < n; ++v) order.emplace_back(tree.depth(v), v);
  std::sort(order.rbegin(), order.rend());

  std::vector<int> nearest(n, kInf);
  std::vector<int> anc_dist;  // dist(v, c) for each centroid ancestor c of v
  int kept = 0;
  for (const auto& entry : order) {
    const int v = entry.second;

    anc_dist.clear();
    int best = kInf;
    for (int c = v; c != -1; c = cparent[c]) {
      const int dv = tree.distance(v, c);
      anc_dist.push_back(dv);  // cached so the update below needs no new LCA
      if (nearest[c] != kInf) best = std::min(best, dv + nearest[c]);
    }
    if (best < d) continue;

    ++kept;
    std::size_t k = 0;
    for (int c = v; c != -1; c = cparent[c], ++k) {
      nearest[c] = std::min(nearest[c], anc_dist[k]);
    }
  }
  return kept;
}

}  // namespace

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  int n = 0;
  int d = 0;
  if (!(std::cin >> n >> d) || n <= 0) {
    std::cout << 0 << "\n";
    return 0;
  }

  // Node i (1 <= i < n) is joined to the index read on its line.
  // Assumes each such index lies in [0, n) and the edges form a tree.
  Graph graph(n);
  for (int i = 1; i < n; ++i) {
    int p = 0;
    std::cin >> p;
    graph[p].push_back(i);
    graph[i].push_back(p);
  }

  std::cout << max_cats(graph, d) << "\n";
}
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...