제출 #1215135

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1215135 | madamadam3 | Cat in a tree (BOI17_catinatree) | C++20
100 / 100
146 ms | 52804 KiB
#include <bits/stdc++.h>

using namespace std;

// MAXN is not referenced anywhere in the visible code.
const int MAXN = 200'001;
// Number of binary-lifting levels stored per vertex (up[u][0 .. MAXLOG-1]).
const int MAXLOG = 18;

// Vertex count and the minimum allowed pairwise distance, both read in main().
int n;
int d;

// adj[u] : neighbours of u (undirected tree edges).
vector<vector<int>> adj;
// up[u][i] : ancestor table filled by dfs(); level 0 is the parent.
vector<vector<int>> up;

// Per-vertex depth and DFS entry/exit stamps, all filled by dfs().
vector<int> depth;
vector<int> tin;
vector<int> tout;

// Running counter that dfs() uses to hand out entry stamps.
int timer = 0;

// Depth-first traversal that fills, for every vertex reached from u:
//   tin / tout : entry stamp and the timer value when the subtree is finished,
//   depth      : 0 for vertex 0, otherwise parent's depth + 1,
//   up[u][i]   : the ancestor obtained by composing two level-(i-1) jumps.
// p is the vertex we arrived from; main() starts with dfs(0, 0), so vertex 0 is
// its own parent and every jump from it stays on vertex 0.
void dfs(int u, int p) {
    tin[u] = timer;
    ++timer;

    depth[u] = (u == 0) ? 0 : depth[p] + 1;

    auto &anc = up[u];
    anc[0] = p;
    for (int level = 1; level < MAXLOG; ++level) {
        anc[level] = up[anc[level - 1]][level - 1];
    }

    for (const int child : adj[u]) {
        if (child != p) {
            dfs(child, u);
        }
    }

    tout[u] = timer;
}

// True when the [tin, tout] interval of u encloses that of v, i.e. u is v
// itself or lies on the path from v up to the root used by dfs().
bool is_ancestor(int u, int v) {
    if (tin[v] < tin[u]) return false;
    return tout[v] <= tout[u];
}

// Lowest common ancestor of u and v using the tables built by dfs().
// If one vertex is an ancestor of the other it is returned directly; otherwise
// u is lifted by decreasing powers of two while it stays a non-ancestor of v,
// and the parent of the final position is the answer.
int find_lca(int u, int v) {
    if (is_ancestor(u, v)) return u;
    if (is_ancestor(v, u)) return v;

    int level = MAXLOG;
    while (level-- > 0) {
        const int jump = up[u][level];
        if (is_ancestor(jump, v)) continue;  // would overshoot the LCA
        u = jump;
    }

    return up[u][0];
}

// Number of edges on the tree path between u and v: the two climbs from each
// endpoint up to their lowest common ancestor.
int dist(int u, int v) {
    const int meet = find_lca(u, v);
    return depth[u] + depth[v] - 2 * depth[meet];
}

// Brute force for tiny inputs (n <= 18): enumerate every subset of vertices and
// print the size of the largest one whose members are pairwise at distance >= d.
//
// dist() is symmetric, so each unordered pair is tested once, and the pair scan
// stops at the first violation. The scratch vector is reused across masks
// instead of being reallocated for every subset.
void task1() { // n <= 18
    int most = 0;
    vector<int> active;

    for (int mask = 0; mask < (1 << n); mask++) {
        active.clear();
        for (int bit = 0; bit < n; bit++) {
            if (mask & (1 << bit)) active.push_back(bit);
        }

        bool valid = true;
        for (size_t i = 0; i < active.size() && valid; ++i) {
            for (size_t j = i + 1; j < active.size(); ++j) {
                if (dist(active[i], active[j]) < d) {
                    valid = false;
                    break;
                }
            }
        }

        if (valid) most = max(most, static_cast<int>(active.size()));
    }

    cout << most;
}

// Quadratic greedy for medium inputs (n <= 1500): repeatedly take the deepest
// vertex that is still available (smallest index on ties), count it, and mark
// every vertex closer than d to it as unavailable. Prints the number of
// vertices taken.
//
// Compared with the earlier version this keeps only what influences the answer:
//  * no n x n distance matrix (each distance was used once, right after being
//    computed);
//  * neighbour lists are never pruned - marking an already-used vertex as used
//    again is a no-op, so pruning could not change the result;
//  * the O(n) "deepest available vertex" scan runs once per round, not twice.
void task2() { // n <= 1500
    // close[i] : vertices j != i with 1 <= dist(i, j) < d.
    vector<vector<int>> close(n);
    for (int i = 0; i < n; i++) {
        for (int j = i + 1; j < n; j++) {
            const int between = dist(i, j);  // symmetric, so compute once per pair
            if (between >= 1 && between < d) {
                close[i].push_back(j);
                close[j].push_back(i);
            }
        }
    }

    vector<bool> used(n, false);

    // Index of the deepest unused vertex, or -1 when none is left.
    auto deepest_unused = [&]() {
        int best = -1, best_depth = -1;
        for (int i = 0; i < n; i++) {
            if (used[i]) continue;
            if (depth[i] <= best_depth) continue;

            best = i;
            best_depth = depth[i];
        }
        return best;
    };

    int tl = 0;
    for (int cur = deepest_unused(); cur >= 0; cur = deepest_unused()) {
        used[cur] = true;
        for (int el : close[cur]) used[el] = true;
        tl++;
    }

    cout << tl;
}

// A multiset of integers with a lazy global offset.
// dfs2() interprets a stored key x as the real value x + add, so bumping `add`
// shifts every element at once without touching the multiset.
struct Bag {
    multiset<int> s;  // stored keys (real value minus `add`)
    int add{0};       // offset applied to every key when it is read
};

// Not referenced by any code visible in this file; kept for compatibility.
int ans = 0;

// Returns the bag for the subtree of u (p is the vertex we came from, -1 at the
// root). Each real value in the bag (key + add) is the distance from u to one
// chosen vertex of the subtree. On return the bag is "pruned": either it has at
// most one element, or its two smallest real values sum to at least d.
//
// Child bags are merged small-to-large by their actual element count: before
// each merge the larger bag becomes the target and only the smaller bag's
// elements are re-inserted. The previous version picked the target by vertex
// degree (adj[v].size()), which does not bound the bag size, so a large bag
// could be copied element by element at every level of the tree.
// The pruned result does not depend on merge order: an element dropped at any
// intermediate step would also be dropped when pruning the full union.
Bag dfs2(int u, int p) {
    Bag big;

    // Drop the smallest element while the two smallest real values are closer
    // than d to each other (their path through u is shorter than d).
    auto prune = [&big]() {
        while (big.s.size() >= 2) {
            auto lowest = big.s.begin();
            auto second = next(lowest);
            if ((*lowest + big.add) + (*second + big.add) >= d) break;
            big.s.erase(lowest);        // drop the shallower one
        }
    };

    for (int v : adj[u]) {
        if (v == p) continue;

        Bag child = dfs2(v, u);
        child.add++;                    // shift +1 for the edge (u, v)

        // Keep the larger bag as the merge target; swapping two Bags only moves
        // the multisets, it does not copy their elements.
        if (child.s.size() > big.s.size()) std::swap(big, child);

        for (int x : child.s) {
            const int real = x + child.add;
            big.s.insert(real - big.add);
            prune();
        }
    }

    // Try to place u itself (real value 0): allowed when nothing is chosen yet
    // or the closest chosen vertex is at distance >= d.
    if (big.s.empty() || (*big.s.begin() + big.add) >= d)
        big.s.insert(-big.add);

    // Inserting u may have broken the pairwise condition; restore it.
    prune();

    return big;
}

// Full solution: prints the size of the bag dfs2() builds for the tree rooted
// at vertex 0. When d == 1 it prints n directly without running dfs2().
void task3() {
    if (d != 1) {
        Bag root = dfs2(0, -1);
        cout << root.s.size() << '\n';
    } else {
        cout << n << '\n';
    }
}

// Reads n and d, then for each vertex i = 1 .. n-1 one integer naming a vertex
// it shares an edge with. Builds the adjacency lists, runs the dfs()
// preprocessing from vertex 0 and prints the answer via task3().
int main() {
    cin.tie(nullptr);
    ios_base::sync_with_stdio(false);

    cin >> n >> d;

    adj.clear();
    adj.resize(n);
    for (int child = 1; child < n; ++child) {
        int other;
        cin >> other;
        adj[child].push_back(other);
        adj[other].push_back(child);
    }

    // Reset every table dfs() writes to before running it from the root.
    up.assign(n, vector<int>(MAXLOG, 0));
    depth.assign(n, -1);
    tin.assign(n, -1);
    tout.assign(n, -1);
    dfs(0, 0);

    task3();
    // if (n <= 18) task1();
    // else if (n <= 1500) task2();
    // else task3();

    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...