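// Submission #1129900
//
// Columns: # | Time | Username | Problem | Language | Result | Execution time | Memory
// #: 1129900 | Username: vladilius | Problem: Cat in a tree (BOI17_catinatree) | Language: C++20
// Result: 0 / 100
// Execution time: 0 ms | Memory: 328 KiB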
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<int, int>;
#define pb push_back
#define ff first
#define ss second
#define arr3 array<int, 3>

/**
 * Disjoint-set union over the vertices 1..n that also tracks, for every set,
 * the smallest (d[v], v) pair among its members.
 *
 * Index 0 of every array is allocated but never initialised as a set.
 */
struct dsu{
    // sz[r]: number of vertices in the set whose representative is r.
    // p[v]:  parent link of v; p[r] == r for a representative.
    std::vector<int> sz, p;
    // m[r]: lexicographically smallest (d[v], v) over the set represented by r.
    std::vector<std::pair<int, int>> m;
    int n;

    /**
     * Creates n singleton sets {1}, ..., {n}.
     *
     * @param ns number of vertices.
     * @param d  key per vertex; d[1..ns] is read, so d needs at least ns + 1
     *           elements. Taken by const reference because it is only read,
     *           which also lets callers pass const vectors and temporaries.
     */
    dsu(int ns, const std::vector<int>& d){
        n = ns;
        sz.resize(n + 1);
        p.resize(n + 1);
        m.resize(n + 1);
        for (int v = 1; v <= n; v++){
            p[v] = v;
            sz[v] = 1;
            m[v] = {d[v], v};
        }
    }

    /**
     * Returns the representative of v's set and re-links every vertex on the
     * walked path directly to it (full path compression, done iteratively so
     * that no recursion is needed).
     */
    int get(int v){
        int root = v;
        while (p[root] != root){
            root = p[root];
        }
        // Second pass: hang every vertex of the walked chain below the root.
        while (p[v] != root){
            const int up = p[v];
            p[v] = root;
            v = up;
        }
        return root;
    }

    /**
     * Merges the sets containing x and y (no-op when they already coincide).
     * The smaller set is attached below the larger one, and the surviving
     * representative keeps the smaller of the two (d, vertex) minima.
     */
    void unite(int x, int y){
        x = get(x);
        y = get(y);
        if (x == y) return;
        if (sz[x] > sz[y]) std::swap(x, y);
        p[x] = y;
        sz[y] += sz[x];
        m[y] = std::min(m[y], m[x]);
    }
};

/**
 * Segment tree over positions 1..n holding one int value per position.
 *
 * Each node stores the minimum and the maximum (value, position) pair of its
 * segment. A position that holds no value is encoded as min = (kAbsent, 0),
 * max = (0, 0), which is what bye() writes. Every node starts in that state,
 * so positions that were never upd()-ed are neither counted in the minimum
 * nor reported by find().
 */
struct ST{
    // Value used in the "min" half of a node to mark an empty position.
    static constexpr int kAbsent = 1000000000;

    // t[v].first = min (value, position), t[v].second = max (value, position).
    // Heap layout: root is t[1], children of v are 2v and 2v + 1.
    std::vector<std::pair<std::pair<int, int>, std::pair<int, int>>> t;
    int n;
    // Positions removed by the most recent find(l, r, x), in increasing order.
    std::vector<int> rem;
    // Scratch list of {node, tl, tr} triples produced by dec().
    std::vector<std::array<int, 3>> all;

    /** Builds an empty tree over positions 1..ns. */
    ST(int ns){
        n = ns;
        // At least one block of four nodes so that t[1] exists even for ns <= 0.
        const int slots = 4 * std::max(n, 1);
        t.assign(slots, {{kAbsent, 0}, {0, 0}});
    }

    /** Recomputes node v from its two children. */
    void pull(int v){
        const int lc = 2 * v;
        const int rc = lc + 1;
        t[v].first = std::min(t[lc].first, t[rc].first);
        t[v].second = std::max(t[lc].second, t[rc].second);
    }

    /** Internal: stores value x at position p inside node v = [tl, tr]. */
    void upd(int v, int tl, int tr, int p, int x){
        if (tl == tr){
            t[v] = {{x, tl}, {x, tl}};
            return;
        }
        const int tm = (tl + tr) / 2;
        if (p <= tm){
            upd(2 * v, tl, tm, p, x);
        }
        else {
            upd(2 * v + 1, tm + 1, tr, p, x);
        }
        pull(v);
    }

    /** Stores value x at position p (1 <= p <= n). */
    void upd(int p, int x){
        upd(1, 1, n, p, x);
    }

    /**
     * Internal: appends to `all` the maximal nodes below v = [tl, tr] whose
     * segments lie completely inside [l, r], in left-to-right order.
     */
    void dec(int v, int tl, int tr, int l, int r){
        if (l > tr || r < tl) return;
        if (l <= tl && tr <= r){
            all.push_back({v, tl, tr});
            return;
        }
        const int tm = (tl + tr) / 2;
        dec(2 * v, tl, tm, l, r);
        dec(2 * v + 1, tm + 1, tr, l, r);
    }

    /** Internal: marks position p as empty inside node v = [tl, tr]. */
    void bye(int v, int tl, int tr, int p){
        if (tl == tr){
            t[v] = {{kAbsent, 0}, {0, 0}};
            return;
        }
        const int tm = (tl + tr) / 2;
        if (p <= tm){
            bye(2 * v, tl, tm, p);
        }
        else {
            bye(2 * v + 1, tm + 1, tr, p);
        }
        pull(v);
    }

    /**
     * Internal: removes every position in node v = [tl, tr] whose value is
     * <= x and appends it to `rem`. Subtrees whose minimum exceeds x are
     * skipped, so the work is proportional to the number of removals.
     */
    void find(int v, int tl, int tr, int x){
        if (t[v].first.first > x) return;
        if (tl == tr){
            rem.push_back(tl);
            // Remove from the root so that all ancestors are refreshed.
            bye(1, 1, n, tl);
            return;
        }
        const int tm = (tl + tr) / 2;
        find(2 * v, tl, tm, x);
        find(2 * v + 1, tm + 1, tr, x);
    }

    /**
     * Removes every stored position in [l, r] whose value is <= x.
     * Afterwards `rem` lists exactly the removed positions in increasing order.
     */
    void find(int l, int r, int x){
        rem.clear();
        all.clear();
        if (n <= 0) return;
        dec(1, 1, n, l, r);
        for (const auto& node : all){
            find(node[0], node[1], node[2], x);
        }
    }
};

/**
 * Reads a rooted tree and prints the largest number of vertices that can be
 * picked so that every two picked vertices are at distance >= D.
 *
 * Input: n and D, followed by n - 1 integers; the i-th of them is the 0-based
 * parent of vertex i (vertex 0 is the root). Vertices are shifted to 1..n
 * internally so that 0 can serve as "no parent".
 *
 * Method: repeatedly pick the deepest vertex y still present and delete every
 * vertex closer than D to it. For an ancestor v of y, a vertex u in v's
 * subtree is within reach through v exactly when
 *     (d[y] - d[v]) + (d[u] - d[v]) <= D - 1,
 * i.e. d[u] <= d[v] + s with s = D - 1 - (d[y] - d[v]). Every ancestor with
 * s >= 0 is visited. Ancestors must not be skipped just because they were
 * deleted in an earlier round: their subtrees can still contain vertices that
 * are close to the current y but were far from the earlier picks.
 *
 * Cost: at most D ancestors are visited per pick, and picks are pairwise >= D
 * apart, so there are at most about 2n / D of them; the walks total O(n)
 * segment-tree queries, and every vertex is deleted once.
 */
int main(){
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n, D;
    if (!(std::cin >> n >> D) || n <= 0){
        std::cout << 0 << "\n";
        return 0;
    }
    // No two vertices are further apart than n - 1, so any D > n behaves like
    // n; D < 1 forbids nothing and behaves like 1. Clamping keeps d[v] + s
    // inside int and guarantees that every round deletes its picked vertex.
    D = std::min(std::max(D, 1), n);

    std::vector<int> p(n + 1, 0);
    std::vector<std::vector<int>> kids(n + 1);
    for (int i = 2; i <= n; i++){
        std::cin >> p[i];
        p[i]++;
        kids[p[i]].push_back(i);
    }

    // Iterative Euler tour from the root: d = depth (root has depth 1),
    // [tin[v], tout[v]] = range of entry times inside v's subtree.
    std::vector<int> d(n + 1, 0), tin(n + 1, 0), tout(n + 1, 0);
    std::vector<std::size_t> next_kid(n + 1, 0);
    std::vector<int> path;
    int timer = 0;
    d[1] = 1;
    tin[1] = ++timer;
    path.push_back(1);
    while (!path.empty()){
        const int v = path.back();
        if (next_kid[v] < kids[v].size()){
            const int c = kids[v][next_kid[v]++];
            d[c] = d[v] + 1;
            tin[c] = ++timer;
            path.push_back(c);
        }
        else {
            tout[v] = timer;
            path.pop_back();
        }
    }

    // Segment tree indexed by entry time, storing depths; inv maps an entry
    // time back to its vertex.
    ST T(n);
    std::vector<int> inv(n + 1, 0);
    for (int i = 1; i <= n; i++){
        T.upd(tin[i], d[i]);
        inv[tin[i]] = i;
    }

    int out = 0;
    while (true){
        // Root maximum = (depth, entry time) of the deepest vertex left;
        // depth 0 means the tree is empty because real depths start at 1.
        const auto [top_depth, top_pos] = T.t[1].second;
        if (top_depth == 0) break;
        out++;

        const int y = inv[top_pos];
        for (int v = y; v != 0; v = p[v]){
            const int s = D - 1 - (d[y] - d[v]);
            if (s < 0) break;
            T.find(tin[v], tout[v], d[v] + s);
        }
    }

    std::cout << out << "\n";
    return 0;
}
// Columns: # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// Columns: # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// Columns: # | Verdict | Execution time | Memory | Grader output
// Fetching results...