// Submission #955756
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 955756 | (not captured) | caterpillow | Road Closures (APIO21_roads) | C++17 | 100 / 100 | 1738 ms | 49648 KiB
#include <bits/stdc++.h>

using namespace std;

// Short type aliases used throughout the file.
using ll = long long;
using pl = pair<ll, ll>;
// Shorthand macros. NOTE: `f` and `s` rewrite EVERY identifier with that name,
// so nothing in this file may declare a variable or function called `f` or `s`.
#define vt vector
#define f first
#define s second
#define all(x) x.begin(), x.end() 
#define pb push_back
// FOR: i ascends over [a, b).  ROF: i descends from b - 1 down to a.  F0R: i ascends over [0, b).
#define FOR(i, a, b) for (int i = (a); i < (b); i++)
#define ROF(i, a, b) for (int i = (b) - 1; i >= (a); i--)
#define F0R(i, b) FOR (i, 0, b)
// Replaces the identifier `endl` with a plain newline character literal.
#define endl '\n'
// Prints the text of the expression and its value to cerr.
#define debug(x) do{auto _x = x; cerr << #x << " = " << _x << endl;} while(0)
// Sentinel cost; smol() returns it when asked for the minimum of an empty treap.
const ll INF = 1e18;

// Random source for treap priorities: every Node draws one value from `mt`.
random_device rd;
mt19937 mt(rd());
// Treap links are raw pointers to Node.
using ptr = struct Node*;

// One treap node. The tree is ordered by `val` and heap-ordered by `pri`
// (merge() keeps the node with the larger `pri` on top).
//
// A node owns both of its subtrees: destroying it destroys everything below it.
// Because of that ownership the type is non-copyable; a copy would share the
// child pointers and the two destructors would free the same subtree twice.
struct Node {
    int pri;      // random heap priority, drawn once from `mt` at construction
    ll val, agg;  // val: key stored here; agg: sum of `val` over this subtree (kept up to date by pull())
    int sz;       // number of nodes in this subtree (kept up to date by pull())
    ptr l, r;     // owned left / right children, nullptr when absent

    // Builds a single-node treap holding `v`.
    explicit Node(ll v)
        : pri(static_cast<int>(mt())), val(v), agg(v), sz(1), l(nullptr), r(nullptr) {}

    // Owning raw pointers: forbid copying (this also suppresses implicit moves).
    Node(const Node&) = delete;
    Node& operator=(const Node&) = delete;

    // Recursively frees both subtrees; deleting a null child is a no-op.
    ~Node() {
        delete l;
        delete r;
    }
};

// Sum stored for the subtree rooted at `n`; an empty subtree contributes 0.
inline ll agg(ptr n) {
    if (n == nullptr) return 0;
    return n->agg;
}
// Key stored at node `n`; 0 for a null pointer.
inline ll val(ptr n) {
    if (n == nullptr) return 0;
    return n->val;
}
// Node count of the subtree rooted at `n`; 0 for an empty subtree.
inline ll sz(ptr n) {
    if (n == nullptr) return 0;
    return n->sz;
}

// Recomputes the cached size and sum of `n` from its children and returns `n`.
// A null pointer is passed through untouched.
ptr pull(ptr n) {
    if (n != nullptr) {
        n->agg = agg(n->l) + n->val + agg(n->r);
        n->sz = sz(n->l) + 1 + sz(n->r);
    }
    return n;
}

// Splits the treap `n` by key: the first result holds every node with val < v,
// the second every node with val >= v. The input tree is consumed.
pair<ptr, ptr> split(ptr n, ll v) {
    if (n == nullptr) return {nullptr, nullptr};

    if (n->val < v) {
        // `n` and its left subtree belong to the low side; only the right subtree can straddle v.
        pair<ptr, ptr> parts = split(n->r, v);
        n->r = parts.first;
        return {pull(n), parts.second};
    }

    // `n` and its right subtree belong to the high side; only the left subtree can straddle v.
    pair<ptr, ptr> parts = split(n->l, v);
    n->l = parts.second;
    return {parts.first, pull(n)};
}

// Splits the treap `n` by position: the first result holds the first `i` nodes
// in key order, the second holds the rest. The input tree is consumed.
pair<ptr, ptr> spliti(ptr n, int i) {
    if (n == nullptr) return {nullptr, nullptr};

    const ll left_count = sz(n->l);
    if (i > left_count) {
        // The whole left subtree and `n` itself fall into the prefix; cut inside the right subtree.
        pair<ptr, ptr> parts = spliti(n->r, i - left_count - 1);
        n->r = parts.first;
        return {pull(n), parts.second};
    }

    // The cut lies inside the left subtree; `n` and its right subtree go to the suffix.
    pair<ptr, ptr> parts = spliti(n->l, i);
    n->l = parts.second;
    return {parts.first, pull(n)};
}

// Concatenates two treaps, placing every node of `l` before every node of `r`.
// The caller must make sure the keys of `l` do not exceed those of `r`.
// The node with the larger priority becomes the root.
ptr merge(ptr l, ptr r) {
    if (l == nullptr) return r;
    if (r == nullptr) return l;

    if (l->pri > r->pri) {
        l->r = merge(l->r, r);
        return pull(l);
    }

    r->l = merge(l, r->l);
    return pull(r);
}

// Removes one occurrence of `v` from the treap rooted at `n` and frees its node.
// If `v` is not stored in the treap, the treap is left unchanged.
void erase(ptr& n, ll v) {
    // l: every value < v; rhs: every value >= v.
    auto [l, rhs] = split(n, v);
    // m: the single smallest node of rhs (null when rhs is empty); r: the remainder.
    auto [m, r] = spliti(rhs, 1);

    if (m && m->val != v) {
        // The smallest value >= v is not v itself, so there is nothing to erase:
        // put the detached node back in front of the remainder.
        r = merge(m, r);
    } else {
        // spliti(…, 1) hands back the node with both child links cleared, so the
        // recursive destructor cannot reach nodes that are still in the treap.
        // Previously this node was simply dropped and leaked.
        delete m;
    }

    n = merge(l, r);
}

// Adds a new node with key `v` to the treap rooted at `n` (duplicates are kept).
void insert(ptr& n, ll v) {
    // Cut the treap at v, then stitch the fresh node between the two halves.
    pair<ptr, ptr> parts = split(n, v);
    ptr fresh = new Node(v);
    ptr upper = merge(fresh, parts.second);
    n = merge(parts.first, upper);
}

// Union of two treaps whose key ranges may interleave. Both inputs are consumed
// and every node of both ends up in the returned treap.
ptr unite(ptr l, ptr r) {
    if (l == nullptr) return r;
    if (r == nullptr) return l;

    // The root with the higher priority stays on top; the other treap is cut
    // around its key and the pieces are united into its two subtrees.
    ptr top = l;
    ptr other = r;
    if (top->pri < other->pri) swap(top, other);

    pair<ptr, ptr> parts = split(other, top->val);
    top->l = unite(top->l, parts.first);
    top->r = unite(top->r, parts.second);
    return pull(top);
}

// Smallest key in the treap rooted at `n`, or INF when the treap is empty.
ll smol(ptr n) {
    if (n == nullptr) return INF;
    // The minimum sits at the end of the left spine.
    while (n->l != nullptr) n = n->l;
    return n->val;
}

/*

For each k, the answer is computed in O((number of nodes with degree > k) * log^2 n).

*/


// Shared state: sized and filled by minimum_closure_costs(), read and updated by dfs().

// Number of vertices (copied from the `_n` argument).
int n;
// adj[x]: (neighbour, edge weight) pairs of x, sorted by neighbour degree, largest first.
vt<vt<pl>> adj;
// deg[x]: number of edges incident to x.
vt<int> deg;
// All vertex indices, sorted by degree, largest first.
vt<int> ord;
// ans[k]: accumulated result for each k; this vector is what minimum_closure_costs() returns.
vt<ll> ans;
// weights[x]: root of a treap of edge weights incident to x. Entries are erased
// once the neighbour's degree equals the k being processed, and dfs() temporarily
// unites extra values into it.
vt<ptr> weights;

// seen[x]: set by dfs() when it visits x; cleared again for vertices with deg > k after each k.
vt<bool> seen;

// Tree DP for one fixed k, started at `u` and restricted to neighbours whose
// degree exceeds k. Marks every visited vertex in `seen`.
//
// u   - vertex being processed
// k   - degree limit of the current round
// par - vertex we arrived from (-1 for the start vertex); its edge is skipped
//
// Returns {take_cost, dont_take_cost} for the part of the tree below u, with
// req = deg[u] - k:
//   first  - cost when only req - 1 entries have to be picked at u
//            (labelled "take parent edge" by the original author),
//   second - cost when req entries have to be picked at u
//            ("dont take parent edge").
// NOTE(review): presumably "take" means the edge to `par` is itself closed and
// therefore already counts towards u's quota — confirm against the problem statement.
pl dfs(int u, int k, int par = -1) { // {take parent edge, dont take parent edge}

    seen[u] = true;
    // One entry per visited child: extra cost of picking the edge to that child
    // compared with not picking it.
    vt<ll> diff;
    ll take_cost = 0;

    // adj[u] is sorted by neighbour degree (largest first), so the loop can stop
    // at the first neighbour whose degree is <= k.
    for (auto [v, w] : adj[u]) {
        if (v == par) continue;
        if (deg[v] > k) {
            auto [take, dont_take] = dfs(v, k, u);
            // Baseline: the child edge is not picked ...
            take_cost += dont_take;
            // ... and picking it instead changes the total by this amount.
            diff.pb(take - dont_take + w);
        } else break;
    }


    ll dont_take_cost = take_cost;
    // Number of entries that have to be picked at u (may be reduced by one, see below).
    ll req = deg[u] - k;

    // Sorted so that appending the values one by one with merge() keeps the treap ordered.
    sort(all(diff));

    // CURSED CODE INCOMING

    // Temporary treap holding the diff values.
    ptr sussy = nullptr;
    for (auto e : diff) sussy = merge(sussy, new Node(e));

    // Reference, so the assignments below update weights[u] itself.
    ptr& root = weights[u];

    root = unite(root, sussy);

        // merged into one fat treap
    
    // tl: all values < 0, tr: all values >= 0.
    auto [tl, tr] = split(root, 0); // automatically take all the negative ones
    int taken = sz(tl);   
    take_cost += agg(tl);
    dont_take_cost += agg(tl);

    // The negative values alone do not reach the quota: top up with the smallest
    // non-negative ones.
    if (taken < req) {
        // l: the req - 1 - taken smallest values of tr, r: everything after them.
        auto [l, r] = spliti(tr, req - 1 - taken);

        take_cost += agg(l);
        // One more pick than take_cost: the smallest remaining value (INF if none is left).
        dont_take_cost = take_cost + smol(r);

        tr = merge(l, r);
    }

    // Stitch the treap back together in key order.
    root = merge(tl, tr);


    // Undo the temporary union: remove one node per diff value so that
    // weights[u] holds the same values as before the call.
    for (auto e : diff) erase(root, e);

    // END CURSED CODE
    return {take_cost, dont_take_cost};
}


// Entry point. Builds the tree from the edge lists (u[i], v[i], w[i]) for
// i in [0, _n - 1), then fills ans[k] for every k in [0, _n) and returns `ans`.
// The global state (n, adj, deg, ord, ans, weights, seen) is sized and filled here.
vt<ll> minimum_closure_costs(int _n, vt<int> u, vt<int> v, vt<int> w) {
    n = _n;

    adj.resize(n);
    deg.resize(n);
    ord.resize(n);
    iota(ord.begin(), ord.end(), 0);
    weights.resize(n, nullptr);
    seen.resize(n);

    // Register every edge on both endpoints: adjacency, degree and weight treap.
    for (int e = 0; e < n - 1; e++) {
        const int a = u[e];
        const int b = v[e];
        const int cost = w[e];
        adj[a].push_back({b, cost});
        adj[b].push_back({a, cost});
        deg[a]++;
        deg[b]++;
        insert(weights[a], cost);
        insert(weights[b], cost);
    }

    // Vertices, and each adjacency list, ordered by degree with the largest first.
    sort(ord.begin(), ord.end(), [&](int lhs, int rhs) {
        return deg[lhs] > deg[rhs];
    });
    for (vt<pl>& edges : adj) {
        sort(edges.begin(), edges.end(), [&](pl lhs, pl rhs) {
            return deg[lhs.first] > deg[rhs.first];
        });
    }

    ans.resize(n);

    for (int k = n - 1; k >= 0; k--) {
        // Run the DP once from every not-yet-visited vertex with degree > k.
        for (int start : ord) {
            if (deg[start] <= k) break;
            if (seen[start]) continue;
            ans[k] += dfs(start, k).second;
        }

        // reset: clear the visit marks, and for vertices of degree exactly k
        // drop their edge weights from the neighbours' treaps.
        for (int node : ord) {
            if (deg[node] < k) break;
            if (deg[node] > k) {
                seen[node] = false;
                continue;
            }
            for (const pl& edge : adj[node]) {
                erase(weights[edge.first], edge.second);
            }
        }
    }

    return ans;
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...