이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include <bits/stdc++.h>
using namespace std;
// Array capacity for vertex ids (main() uses vertices 1..n); the +25 is presumably slack.
constexpr int MAXN = 200000 + 25;
// Adjacency lists of the input tree; main() stores every edge in both directions.
std::vector<int> adj[MAXN];
// p[v]: parent of v in the centroid tree (-1 for the first centroid found).
int p[MAXN];
// sze[v]: size of v's subtree inside the current component, as filled by calc().
int sze[MAXN];
// NOTE(review): not referenced anywhere in the visible code.
int depth[MAXN];
// vis[v]: v was already picked as a centroid, so calc()/find() treat it as removed.
bool vis[MAXN];
// Recomputes sze[] for the component containing `pos`, rooted at `pos` with
// parent `par`. Neighbours already marked in vis[] are skipped (treated as removed).
void calc (int pos, int par) {
    int total = 1;  // counts pos itself
    for (int nb : adj[pos]) {
        if (nb == par || vis[nb]) continue;
        calc(nb, pos);
        total += sze[nb];
    }
    sze[pos] = total;
}
// Starting at `pos` (coming from `par`), repeatedly steps into the first
// non-removed neighbour whose sze[] exceeds u / 2. Returns the first vertex that
// has no such neighbour, which is the centroid of a component of `u` vertices.
// Expects sze[] to have been filled by calc() from the same starting vertex.
int find (int pos, int par, int u) {
    const int half = u / 2;
    while (true) {
        int heavy = -1;
        for (int nb : adj[pos]) {
            if (nb != par && !vis[nb] && sze[nb] > half) {
                heavy = nb;
                break;
            }
        }
        if (heavy == -1) return pos;
        par = pos;
        pos = heavy;
    }
}
// Centroid decomposition of the component containing `pos`. The centroid found
// is marked in vis[] and recorded with centroid-tree parent `par` in p[]. Every
// neighbouring piece that is still unremoved is then decomposed recursively
// beneath that centroid.
void dfs (int pos, int par) {
    calc(pos, -1);
    const int centroid = find(pos, -1, sze[pos]);
    vis[centroid] = true;
    p[centroid] = par;
    for (int nb : adj[centroid])
        if (!vis[nb]) dfs(nb, centroid);
}
// Number of vertices (vertices are numbered 1..n), read in main().
int n;
// Distance threshold: check() accepts a vertex only when its computed minimum is >= d.
int d;
struct LCA {
int p2[MAXN] = {}, dep[MAXN] = {}, dp[MAXN][18] = {};
void dfs (int pos, int par) {
p2[pos] = par;
for (auto j : adj[pos]) {
if (j == par) continue;
dep[j] = 1 + dep[pos];
dfs(j, pos);
}
}
int jump (int a, int b) {
for (int i = 17; i >= 0; i--) {
if ((b & (1 << i)) && dp[a][i]) a = dp[a][i];
else if (b & (1 << i)) return -1;
}
return a;
}
int lca (int a, int b) {
if (dep[a] < dep[b]) swap(a, b);
int u = dep[a] - dep[b];
a = jump(a, u);
if (a == b) return a;
for (int i = 17; i >= 0; i--) {
int x = dp[a][i], y = dp[b][i];
if (x && y && x != y) a = x, b = y;
}
return jump(a, 1);
}
int dist (int a, int b) { return dep[a] + dep[b] - 2 * dep[lca(a, b)]; }
void init () {
dfs(1, 0);
for (int i = 1; i <= n; i++) dp[i][0] = p2[i];
for (int i = 1; i <= 17; i++) for (int j = 1; j <= n; j++) dp[j][i] = dp[dp[j][i - 1]][i - 1];
}
};
// Shared ancestor/distance oracle; main() calls cur.init() before any query.
LCA cur;
// pp[c]: smallest tree distance from centroid c to any vertex passed to mark().
// main() resets it to 1e9 before the selection loop.
int pp[MAXN] = {};
// Computes the minimum of 1e9 and, over every centroid-tree ancestor u of x
// (x included), dist(x, u) + pp[u]. Returns whether that minimum is >= d. Via
// the centroid decomposition, this tests whether x is at least d away from
// every vertex already recorded by mark().
bool check (int x) {
    const int INF = 1e9;
    int best = INF;
    for (int u = x; u != -1; u = p[u])
        best = min(best, cur.dist(x, u) + pp[u]);
    return best >= d;
}
// Records x as selected. For every centroid-tree ancestor u of x (x included),
// pp[u] is lowered to dist(u, x) when that is smaller.
void mark (int x) {
    for (int anc = x; anc != -1; anc = p[anc]) {
        const int here = cur.dist(anc, x);
        if (here < pp[anc]) pp[anc] = here;
    }
}
int main () {
ios::sync_with_stdio(0);
cin.tie(0);
cin >> n >> d;
for (int i = 2; i <= n; i++) {
int x; cin >> x; x++;
adj[i].push_back(x);
adj[x].push_back(i);
}
cur.init();
dfs(1, -1);
vector <int> dd;
for (int i = 1; i <= n; i++) pp[i] = 1e9;
for (int i = 1; i <= n; i++) dd.push_back(i);
sort(dd.begin(), dd.end(), [&] (int &a, int &b) {
return cur.dep[a] > cur.dep[b];
});
int ans = 0;
vector <int> ll;
for (auto i : dd) if (check(i)) {
ans++; mark(i);
}
cout << ans << '\n';
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |