Submission #591522

# Submission time Handle Problem Language Result Execution time Memory
591522 2022-07-07T14:34:00 Z MohammadAghil Cat in a tree (BOI17_catinatree) C++17
0 / 100
3 ms 4948 KB
#include <iostream>
#include <algorithm>
#include <functional>
#include <random>
#include <cmath>
#include <vector>
#include <array>
#include <set>
#include <map>
#include <queue>
#include <cassert>
#include <string>
#include <bitset>
#include <numeric>
#include <iomanip>
#include <limits.h>
#include <tuple>
using namespace std;

// Shorthand macros used throughout this file.
// rep: counts i upward from l while i < r (r itself is excluded).
#define rep(i,l,r) for(int i = l; i < (r); i++)
// per: counts i downward from r while i >= l (both ends included).
#define per(i,r,l) for(int i = r; i >= (l); i--)
// sz: size of a container, cast to int.
#define sz(x) (int)size(x)
#define pb push_back
// all: expands to the begin/end iterator pair of x.
#define all(x) begin(x), end(x)
// ff / ss: the two members of a std::pair.
#define ff first
#define ss second
typedef long long ll;
typedef pair<int, int> pp;

// maxn bounds the per-node arrays below, inf is the "no value yet" sentinel
// for distances (see min_dist and get), and lg is the number of levels in the
// rmq sparse table. mod is not referenced anywhere in the visible code.
const ll mod = 1e9+7, maxn = 2e5 + 5, inf = ll(1e9) + 5, lg = 20;

// adj[v]: neighbours of node v (main adds every edge in both directions).
vector<int> adj[maxn];
// cnt[v]      : component subtree size, written by dfs() during decomposition.
// h[v]        : depth of v below the node dfslca() was started from.
// par[v]      : parent of v in the centroid tree built by dec(); main presets
//               it to -2 ("not assigned yet"), and the top centroid gets -1.
// min_dist[v] : smallest dist(v, u) over all u passed to upd() so far whose
//               centroid-ancestor chain contains v; starts at inf.
// t           : number of entries dfslca() has written into rmq so far.
// lgg[i]      : floor(log2(i)), filled in by build().
// st[v]       : index in rmq at which dfslca() first recorded v.
int cnt[maxn], h[maxn], par[maxn], min_dist[maxn], t, lgg[maxn<<1], st[maxn];
// rmq[i][j]: minimum (depth, node) pair over the 2^j tour entries starting at i.
pp rmq[maxn<<1][lg];

void build(){
    rep(i,2,t+1) lgg[i] = lgg[i>>1] + 1;
    rep(j,1,lg) rep(i,0,t-(1<<j)+1) rmq[i][j] = min(rmq[i][j-1], rmq[i+(1<<(j-1))][j-1]);
}

// Writes the Euler tour of the tree rooted at r into level 0 of the rmq table
// and fills h[] (depths) and st[] (first tour index of each node).
//
// r: node being visited.
// p: the neighbour r was reached from; that neighbour is not descended into.
//    NOTE(review): the default of 0 only works for the top-level call when the
//    root is node 0 (a node is never its own neighbour) - confirm before
//    rooting anywhere else.
//
// Each node is recorded once on entry AND once more after every child returns.
// The re-recording is what makes lca() correct: the tour segment between the
// first occurrences of u and v then always contains their lowest common
// ancestor as its minimum-depth entry. Recording on entry only (a plain
// preorder) leaves the ancestor outside that segment whenever it is neither
// u nor v. The full tour has 2n-1 entries, which fits the maxn<<1 rows of rmq.
void dfslca(int r, int p = 0){
    st[r] = t, rmq[t++][0] = {h[r], r};
    for(int c: adj[r]){
        if(c == p) continue;
        h[c] = h[r] + 1;
        dfslca(c, r);
        // Back at r after finishing c's subtree: append r to the tour again.
        rmq[t++][0] = {h[r], r};
    }
}

int lca(int u, int v){
    if(st[u] > st[v]) swap(u, v);
    int k = lgg[st[v] - st[u] + 1];
    return min(rmq[st[u]][k], rmq[st[v] - (1<<k) + 1][k]).ss;
}

// Number of edges on the path between u and v, derived from the depths of
// both nodes and of the node returned by lca(u, v).
int dist(int u, int v){
    const int meet = lca(u, v);
    return h[u] + h[v] - 2 * h[meet];
}

// Computes cnt[] (subtree sizes) for the component containing r, walking only
// through nodes whose centroid parent is still unassigned (par == -2).
// p is the node we came from and is not revisited.
void dfs(int r, int p = -1){
    cnt[r] = 1;
    for(int c: adj[r]){
        if(c == p || par[c] != -2) continue;
        dfs(c, r);
        cnt[r] += cnt[c];
    }
}

// Walks from r towards the first still-unassigned child whose subtree holds
// more than half of cnt[bs], and returns the node at which no such child
// exists. bs is the node whose cnt[] gives the component size; p is the node
// the walk arrived from.
int find_cnt(int r, int bs, int p = -1){
    const int total = cnt[bs];
    for(;;){
        int heavy = -1;
        for(int c: adj[r]){
            if(c != p && par[c] == -2 && 2 * cnt[c] > total){
                heavy = c;
                break;
            }
        }
        if(heavy == -1) return r;
        p = r;
        r = heavy;
    }
}

// Centroid-decomposes the still-unassigned component containing r.
// The component's centroid gets p as its centroid-tree parent (-1 for the top
// level), then every remaining neighbouring component is decomposed below it.
void dec(int r, int p = -1){
    dfs(r);
    const int centroid = find_cnt(r, r);
    par[centroid] = p;
    for(int nb: adj[centroid]){
        if(par[nb] != -2) continue;  // already placed in the centroid tree
        dec(nb, centroid);
    }
}

// Returns the minimum of min_dist[a] + dist(a, u) over u and all of its
// centroid-tree ancestors a. With nothing recorded yet the result is at
// least inf.
int get(int u){
    int best = inf;
    int anc = u;
    while(anc != -1){
        const int via = min_dist[anc] + dist(anc, u);
        if(via < best) best = via;
        anc = par[anc];
    }
    return best;
}

// Records u as a chosen node: for u and each of its centroid-tree ancestors,
// lowers min_dist to the distance to u if that is smaller.
void upd(int u){
    int anc = u;
    while(anc != -1){
        const int reach = dist(anc, u);
        if(reach < min_dist[anc]) min_dist[anc] = reach;
        anc = par[anc];
    }
}

int main(){
    cin.tie(0) -> sync_with_stdio(0);
    int n, d; cin >> n >> d;
    rep(i,1,n){
        int p; cin >> p;
        adj[p].pb(i), adj[i].pb(p);
    }
    fill(par, par + n, -2);
    fill(min_dist, min_dist + n, inf);
    dec(0), dfslca(0), build();
    vector<int> node(n); iota(all(node), 0), sort(all(node), [&](int u, int v){ return h[u] > h[v]; });
    int ans = 0;
    for(int u: node){
        if(get(u) >= d) upd(u), ans++;
    }
    cout << ans << '\n';
}
# Verdict Execution time Memory Grader output
1 Incorrect 3 ms 4948 KB Output isn't correct
2 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Incorrect 3 ms 4948 KB Output isn't correct
2 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Incorrect 3 ms 4948 KB Output isn't correct
2 Halted 0 ms 0 KB -