// Submission #699130
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 699130 | - | qwerasdfzxcl | Road Closures (APIO21_roads) | C++17 | 100 / 100 | 261 ms | 64640 KiB
#include "roads.h"
#include <bits/stdc++.h>

// Standard-library names (vector, pair, list, min, sort, ...) are used
// unqualified throughout this file.
using namespace std;
// 64-bit signed shorthand used for edge-cost sums and dp values.
typedef long long ll;
// Initial value of the running minimum in Seg2::query (compared against the
// depth component of the stored (depth, vertex) pairs).
constexpr int INF1 = 1e9 + 100;
// `mn` value of an empty Node1: used by Seg1::init for unfilled slots and by
// Seg1::query for ranges that do not intersect the request.
constexpr ll INF2 = 4e18;
// One edge stored in a vertex's child list G[s].
//   v - vertex the edge leads to
//   w - weight of the edge
//   i - 1-based position the edge had in the list when it was appended
//       (used as the slot index when the edge is moved into a Seg1)
// All members default to 0 so that a default-constructed Edge never carries
// indeterminate values.
struct Edge{
    int v = 0;
    int w = 0;
    int i = 0;

    Edge() = default;
    Edge(int _v, int _w, int _i): v(_v), w(_w), i(_i) {}
};

// Number of vertices of the current instance (set from N); vertices are
// handled 1-based inside this file.
int n;
// adj[u]: (weight, neighbour) pairs of vertex u, sorted by dfs0().
// dep:    Euler-tour sequence of (depth, vertex) entries appended by dfs0();
//         dep[0] is a (0, 0) sentinel that serves as the root's "parent" entry.
vector<pair<int, int>> adj[100100], dep;
// G[s]: edges from s to its children in the rooted tree, in sorted adj order.
// dfs() erases an edge from this list once its weight is moved into tree[s].
list<Edge> G[100100];
// dcnt:   running preorder counter used by dfs0().
// in[s]:  preorder number of s.
// in2[s]: index in `dep` of the first tour entry of s.
// out[s]: value of dcnt when dfs0() leaves s (largest preorder number
//         assigned inside s's subtree).
// pW[s]:  weight of the edge from s to its parent (0 for the root).
// deg[s]: degree of s in the input tree.
int dcnt, in[100100], in2[100100], out[100100], pW[100100], deg[100100];
// dp[s][0], dp[s][1]: the two values dfs() computes for s. dp[s][1] starts
// from pW[s] and asks calc() for one removal fewer than dp[s][0] --
// presumably "parent edge kept" vs. "parent edge closed (and paid for)".
ll dp[100100][2];

// pt: index of the next unprocessed entry of V during dfs().
// k:  degree bound currently being solved by minimum_closure_costs().
int pt, k;
// V:  vertices visited by dfs() for the current k, sorted by preorder number
//     (built by getV()).
// on: vertices collected by dfs0() in visit order and then filtered by
//     getnxt() to those whose degree is still at least the next bound.
vector<int> V, on;

struct Node1{
    ll sum, c, mn;
    Node1(){}
    Node1(ll x): sum(x), c(1), mn(x) {}
    Node1(ll _s, ll _c, ll _mn): sum(_s), c(_c), mn(_mn) {}

    Node1 operator + (const Node1 &N) const{
        return Node1(sum + N.sum, c + N.c, min(mn, N.mn));
    }
};
// Segment tree over the slots 1..sz of one vertex. A slot is either empty
// (Node1(0, 0, INF2)) or holds a single edge weight. In this file slots are
// the list positions assigned by dfs0() after sorting the adjacency by
// weight, so slot order follows nondecreasing weight.
//
// An empty tree (sz == 0) is valid: every query on it returns 0 instead of
// descending into nodes that were never allocated.
struct Seg1{
    vector<Node1> tree;
    int sz = 0;

    // Stores weight x in slot p; i covers [l, r]. No-op when p is outside.
    void update(int i, int l, int r, int p, ll x){
        if (p<l || r<p) return;
        if (l==r){
            tree[i] = Node1(x);
            return;
        }

        int m = (l+r)>>1;
        update(i<<1, l, m, p, x); update(i<<1|1, m+1, r, p, x);
        tree[i] = tree[i<<1] + tree[i<<1|1];
    }

    // Right-most slot in [l, r] whose stored weight is <= x, or -1 if none.
    int _search_below(int i, int l, int r, ll x){
        if (tree[i].mn > x) return -1;
        if (l==r) return l;

        int m = (l+r)>>1;
        // Prefer the right half so the right-most match is found first.
        int ret = _search_below(i<<1|1, m+1, r, x);
        if (ret==-1) return _search_below(i<<1, l, m, x);
        return ret;
    }

    // Slot of the c-th stored weight counted from the left. When c exceeds
    // the number of stored weights the descent always goes right and ends
    // at slot r.
    int _search_cnt(int i, int l, int r, int c){
        if (l==r) return l;

        int m = (l+r)>>1;
        if (tree[i<<1].c < c) return _search_cnt(i<<1|1, m+1, r, c - tree[i<<1].c);
        return _search_cnt(i<<1, l, m, c);
    }

    // Aggregate of the slots in [s, e]; an empty aggregate if disjoint.
    Node1 query(int i, int l, int r, int s, int e){
        if (r<s || e<l) return Node1(0, 0, INF2);
        if (s<=l && r<=e) return tree[i];

        int m = (l+r)>>1;
        return query(i<<1, l, m, s, e) + query(i<<1|1, m+1, r, s, e);
    }

    // Resets the tree to _sz empty slots.
    void init(int _sz){sz = _sz; tree.clear(); tree.resize(sz*4 + 2, Node1(0, 0, INF2));}

    // Stores weight x in slot p.
    void update(int p, ll x){
        update(1, 1, sz, p, x);
    }

    // Number of stored weights in slots 1..idx, where idx is the right-most
    // slot whose weight is <= x (0 when there is no such slot).
    int count(ll x){
        if (sz<=0) return 0;  // nothing stored; avoids touching tree[2]
        int idx = _search_below(1, 1, sz, x);
        return static_cast<int>(query(1, 1, sz, 1, idx).c);
    }

    // Sum of the weights in slots 1..idx, where idx is the slot of the c-th
    // stored weight; 0 when c <= 0 or the tree has no slots. If c exceeds the
    // number of stored weights the sum of all stored weights is returned.
    ll query(int c){
        if (c<=0 || sz<=0) return 0;  // sz == 0 would index past the node array
        int idx = _search_cnt(1, 1, sz, c);
        return query(1, 1, sz, 1, idx).sum;
    }
}tree[100100];

// Bottom-up (iterative) range-minimum tree over the Euler-tour array `dep`.
// Leaves live at indices sz..2*sz-1; each entry is a (depth, vertex) pair and
// pairs are compared lexicographically.
struct Seg2{
    pair<int, int> tree[400400];
    int sz;

    // Loads the first n entries of `dep` as leaves and builds the inner nodes.
    void init(int n){
        sz = n;
        for (int j = 0; j < sz; ++j){
            tree[sz + j] = dep[j];
        }
        for (int node = sz - 1; node >= 1; --node){
            tree[node] = min(tree[2*node], tree[2*node + 1]);
        }
    }

    // Minimum pair over the inclusive index range [l, r];
    // {INF1, 0} when the range is empty.
    pair<int, int> query(int l, int r){
        pair<int, int> best = {INF1, 0};
        int lo = l + sz;
        int hi = r + 1 + sz;   // half-open upper bound
        while (lo < hi){
            if (lo & 1){
                best = min(best, tree[lo]);
                ++lo;
            }
            if (hi & 1){
                --hi;
                best = min(best, tree[hi]);
            }
            lo >>= 1;
            hi >>= 1;
        }
        return best;
    }
}rmq;

void dfs0(int s, int pa = 0, int paw = 0){
    //printf("ok %d\n", s);
    on.push_back(s);

    pW[s] = paw;
    in[s] = ++dcnt;
    deg[s] = adj[s].size();

    in2[s] = dep.size();
    dep.emplace_back(dep[in2[pa]].first + 1, s);

    sort(adj[s].begin(), adj[s].end());
    for (auto &[w, v]:adj[s]) if (v!=pa){
        G[s].emplace_back(v, w, G[s].size() + 1);
        dfs0(v, s, w);
        dep.emplace_back(dep[in2[pa]].first + 1, s);
    }

    tree[s].init(G[s].size());
    out[s] = dcnt;
}

// Returns the extra cost of picking t removals at vertex s.
// C holds (x, y) pairs ordered by y - x. A pair is taken while either its
// delta y - x is negative, or the number of weights counted by
// tree[s].count(delta) plus the pairs taken so far (including this one) does
// not exceed t. The scan stops at the first pair that is not taken; the
// remaining t - taken removals are paid via tree[s].query().
ll calc(int s, const vector<pair<ll, ll>> &C, int t){
    int taken = 0;
    ll total = 0;

    for (const auto &entry : C){
        const ll x = entry.first;
        const ll y = entry.second;
        const ll delta = y - x;

        // Short-circuit keeps the tree lookup off the negative-delta path.
        const bool accept = (x > y) || (tree[s].count(delta) + taken + 1 <= t);
        if (!accept) break;

        ++taken;
        total += delta;
    }

    total += tree[s].query(t - taken);
    return total;
}

bool cmp(const pair<ll, ll> &x, const pair<ll, ll> &y) {return x.second-x.first < y.second-y.first;}

// Computes dp[s][0] and dp[s][1] for the current bound k, visiting only the
// vertices listed in V (sorted by preorder number; `pt` is the shared cursor
// and must already point just past s on entry).
//
// Child edges of s that lead to subtrees containing no V vertex are moved
// from the list G[s] into tree[s] (and erased from the list for good); the
// other children contribute an (x, y) pair to C.
// NOTE(review): x / y look like "cost with the child edge kept" / "cost with
// the child edge closed" -- confirm against the intended recurrence.
void dfs(int s){
    /*printf(" in %d\n", s);
    if (s==1){
        printf("G[1]:");
        for (auto &[x, y, z]:G[s]) printf(" %d", x);
        printf("\n");
    }*/
    auto iter = G[s].begin();
    vector<pair<ll, ll>> C;

    // Consume every V entry whose preorder number lies inside s's subtree.
    while(pt < (int)V.size() && in[V[pt]] <= out[s]){
        // Skip children whose subtree ends before V[pt]: they contain no V
        // vertex, so their edge weight goes into tree[s] at the edge's slot.
        while (in[V[pt]] > out[iter->v]){
            tree[s].update(iter->i, iter->w);
            iter = G[s].erase(iter);
        }
        //printf(" s = %d pointing %d\n", s, iter->v);

        // Recurse into the next V vertex; the call advances pt past all V
        // entries inside that vertex's subtree.
        int v = V[pt];
        pt++;
        dfs(v);

        if (iter->v == v){
            // v is a direct child of s: use its two dp values as they are.
            C.emplace_back(dp[v][0], dp[v][1]);
        }
        else{
            // v lies deeper below the child iter->v: take the better of v's
            // two values, and add the child edge's weight for the second option.
            ll val = min(dp[v][0], dp[v][1]);
            C.emplace_back(val, val + iter->w);
        }

        // This child stays in G[s]; move on to the next one.
        iter++;
    }

    // Remaining children contain no V vertex either: move them into tree[s].
    while (iter!=G[s].end()){
        tree[s].update(iter->i, iter->w);
        iter = G[s].erase(iter);
    }
    /*if (s==1){
        printf("G[1]:");
        for (auto &[x, y, z]:G[s]) printf(" %d", x);
        printf("\n");
    }*/

    // dp[s][1] additionally carries the weight of the edge to the parent.
    dp[s][0] = 0, dp[s][1] = pW[s];

    // Degree already within the bound: each collected child simply
    // contributes the cheaper of its two options.
    if (deg[s] <= k){
        for (auto &[x, y]:C){
            dp[s][0] += min(x, y);
            dp[s][1] += min(x, y);
        }
        //printf(" s = %d -> %lld %lld\n", s, dp[s][0], dp[s][1]);
        return;
    }

    // Otherwise start from the first option for every collected child ...
    for (auto &[x, y]:C){
        dp[s][0] += x;
        dp[s][1] += x;
    }

    // ... and let calc() add the cost of the required removals, scanning the
    // pairs in increasing order of y - x.
    sort(C.begin(), C.end(), cmp);

    // dp[s][0] requests deg[s]-k removals, dp[s][1] one fewer.
    dp[s][0] += calc(s, C, deg[s]-k);
    dp[s][1] += calc(s, C, deg[s]-k-1);

    //printf(" s = %d -> %lld %lld\n", s, dp[s][0], dp[s][1]);
}

// Vertex of the minimum (depth, vertex) tour entry between the first tour
// positions of x and y. Requires x to appear strictly before y in the tour.
int lca(int x, int y){
    const int from = in2[x];
    const int to = in2[y];
    assert(in2[x] < in2[y]);
    return rmq.query(from, to).second;
}

// Orders vertices by increasing preorder number.
bool cmp2(int x, int y){
    return in[y] > in[x];
}

void getV(const vector<int> &on){
    pt = 1;
    V.clear();
    V.push_back(1);
    for (int i=0;i<(int)on.size();i++){
        V.push_back(on[i]);
        if (i) V.push_back(lca(on[i-1], on[i]));
    }

    sort(V.begin(), V.end(), cmp2);
    V.erase(unique(V.begin(), V.end()), V.end());

    /*printf("-----------------------------\n");
    printf("k = %d\n", k);
    for (auto &x:V) printf("%d ", x);
    printf("\non: ");
    for (auto &x:on) printf("%d(%d) ", x, in[x]);
    printf("\n");*/
    //printf("%d %d %d\n", in[14], dep[in[14]].first, dep[in[14]].second);
}

// Keeps, in their original order, only the vertices of `on` whose degree is
// at least k.
void getnxt(vector<int> &on, int k){
    const auto dropped = remove_if(on.begin(), on.end(),
                                   [k](int x){ return deg[x] < k; });
    on.erase(dropped, on.end());
}

// Entry point called by the grader.
//   N       - number of vertices (0-based in the input, shifted to 1-based here)
//   U, V, W - endpoints and weight of each of the N-1 edges
// Returns a vector `ans` of size N: ans[0] is the sum of all edge weights and
// ans[k], for k = 1..N-1, is dp[1][0] after running dfs(1) with bound k.
//
// Note: inside this function the parameter V hides the global visit list V;
// the global one is only touched through getV() and dfs().
std::vector<long long> minimum_closure_costs(int N, std::vector<int> U, std::vector<int> V, std::vector<int> W) {
    n = N;

    // Reset every global that dfs0() appends to, so the function can be
    // called more than once. `on` must be cleared as well: after a call with
    // N == 1 it still holds vertex 1, and stale entries would make lca()
    // fail its ordering assertion on the next call.
    dep.clear();
    dep.emplace_back(0, 0);   // sentinel entry acting as the root's parent
    on.clear();
    for (int i=1;i<=n;i++){
        adj[i].clear();
        G[i].clear();
    }
    dcnt = 0;

    vector<ll> ans(1);
    if (n > 1) ans.reserve(n);
    for (int i=0;i<n-1;i++){
        adj[U[i]+1].emplace_back(W[i], V[i]+1);
        adj[V[i]+1].emplace_back(W[i], U[i]+1);
        ans[0] += W[i];   // bound 0: every edge has to be closed
    }

    // Root the tree at vertex 1 and build the LCA structure over the tour.
    dfs0(1);
    rmq.init(dep.size());

    for (k=1;k<=n-1;k++){
        // Visit only the vertices still listed in `on` (plus the root and
        // the LCAs of neighbouring entries).
        getV(on);

        dfs(1);
        ans.push_back(dp[1][0]);

        // Vertices with degree below k+1 are dropped from `on` for the
        // following rounds.
        getnxt(on, k+1);
    }

    return ans;
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...