Submission #923423

# | Time | Username | Problem | Language | Result | Execution time | Memory
923423 | (time not captured) | efedmrlr | Road Closures (APIO21_roads) | C++17 | 100 / 100 | 289 ms | 54612 KiB
#include "roads.h"
// #pragma GCC optimize("O3,Ofast,unroll-loops")
// #pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include <bits/stdc++.h>

using namespace std;


// Shorthand for the signed 64-bit integer type used for every cost and sum below.
#define lli long long int
// Shorthand for make_pair.
#define MP make_pair
// Shorthand for push_back; used on the treap node pool and the per-child vectors.
#define pb push_back
// Counts i from 0 up to (but not including) n.
#define REP(i,n) for(int i = 0; (i) < (n); (i)++)
// Expands to the forward iterator pair of x (begin, end).
#define all(x) x.begin(), x.end()
// Expands to the reverse iterator pair of x; sorting over it yields descending order.
#define rall(x) x.rbegin(), x.rend()

// Turns off synchronisation between C++ streams and C stdio and unties cin
// from cout, so stream I/O is not flushed/locked on every operation.
void fastio() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
}

// Random engine seeded from the steady clock; Node's constructor draws its
// heap priority (Node::y) from it.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
// Not referenced by the code in this listing.
const double EPS = 0.00001;
// "Infinite / infeasible" cost sentinel; treap sums are clamped to it and
// k_sum returns it when fewer values are stored than requested.
const lli INF = 1e17;
// Only mentioned inside a commented-out range check in SegT::insert.
const lli MX = 1e9 + 5;
// Capacity of the per-vertex global vectors declared below.
const int M = 1e5+5;
// ALPH, LGN and MOD are not referenced by the code in this listing.
const int ALPH = 26;
const int LGN = 25;
constexpr int MOD = 1e9+7;
// n: number of vertices (set from N in minimum_closure_costs).
// k: the degree bound currently being solved (set before each solve_k call).
// m and q are not referenced by the code in this listing.
int n,m,q,k;

// deg[v]: number of edges incident to vertex v (vertices are 1-indexed).
vector<int> deg(M, 0);
// p[v] = {parent of v, weight of the edge to that parent}, filled by prec.
vector<array<int,2> > p(M);
// adj[v]: list of {neighbour, edge weight}. It is sorted by descending
// neighbour degree before prec runs, and prec removes the entry pointing
// back to the parent, leaving only children.
vector<vector<array<int,2> > > adj(M, vector<array<int,2> >());
// {degree, vertex} pairs sorted in descending order; solve_k pops vertices
// from the back once their degree equals the current k.
vector<array<int,2> > deg_nodes;
// tin[v]: preorder index assigned by prec.
// vis[v]: the value of k during the most recent dfs visit of v (-1 = never).
vector<int> tin(M, 0), vis(M, -1);
// Next preorder index handed out by prec.
int timer = 0;
// {tin[v], v} for every vertex not yet retired by solve_k, ordered by
// preorder index so that ancestors are iterated before their descendants.
set<array<int,2> > roots;
  
    // One treap node. `x` is the stored key, `y` a random heap priority drawn
    // from `rng`; `sz` and `sum` cache the size and key-sum of the subtree
    // rooted here. `lc` / `rc` are child indices into the owning pool, with
    // -1 meaning "no child".
    struct Node {
        int lc = -1;
        int rc = -1;
        lli x;
        lli y;
        int sz = 1;
        lli sum;

        // Creates a leaf holding `val`; its subtree sum is the key itself.
        Node(int val) : x(val), y(rng()), sum(val) {}
    };
     
    struct SegT {
        vector<Node> data;
        int root = -1;
        SegT() : root(-1) {};
        lli getSum(int t) {
            if(t == -1) return 0;
            return data[t].sum;
        }
        int getSize(int t) {
            if(t == -1) return 0;
            return data[t].sz;
        }
        void merge(int L, int R, int &t) {
            if(L == -1) {
                t = R;
                return;
            }
            if(R == -1) {
                t = L;
                return;
            }
            if(data[L].y > data[R].y) {
                t = L;
                merge(data[t].rc, R, data[t].rc);
            }
            else {
                t = R;
                merge(L, data[t].lc, data[t].lc);
            }
            data[t].sum = getSum(data[t].lc) + getSum(data[t].rc) + data[t].x;
            data[t].sum = min(data[t].sum, INF);
            data[t].sz = getSize(data[t].lc) + getSize(data[t].rc) + 1;
        }
        void split(int t, int &L, int &R, int x) {
            if(t == -1) {
                L = R = -1;
                return;
            }
            if(data[t].x <= x) {
                L = t;
                split(data[t].rc, data[t].rc, R, x);
            }
            else {
                R = t;
                split(data[t].lc, L, data[t].lc, x);
            }
            data[t].sum = getSum(data[t].lc) + getSum(data[t].rc) + data[t].x;
            data[t].sum = min(data[t].sum, INF);
            data[t].sz = getSize(data[t].lc) + getSize(data[t].rc) + 1;
        }
        void split_k(int t, int &L, int &R, int ks, int add = 0) {
            if(t == -1) {
                L = R = -1;
                return;
            }
            int sub = getSize(data[t].lc) + add;
            if(sub < ks) {
                L = t;
                split_k(data[t].rc, data[t].rc, R, ks, sub + 1);
            } 
            else {
                R = t;
                split_k(data[t].lc, L, data[t].lc, ks, add);
            }
            data[t].sum = getSum(data[t].lc) + getSum(data[t].rc) + data[t].x;
            data[t].sum = min(data[t].sum, INF);
            data[t].sz = getSize(data[t].lc) + getSize(data[t].rc) + 1;
        }
        void insert(int val) {
            // cout << "insert:" << val << "\n";
            // if(!(0ll < val && val < MX)) {
            //  return;
            // }
            data.pb(Node(val));
            int tl, tr;
            split(root, tl, tr, val);
            merge(tl, (int)data.size() - 1, root);
            merge(root, tr, root);
        }
 
        lli k_sum(int sm) {
            int tl, tr;
            split_k(root, tl, tr, sm);
            lli ans = getSum(tl);
            if(getSize(tl) < sm) {
                return INF;
            }
            // assert(getSize(tl) == sm);
            // cout << sm << " sum : " << ans << "\n";

            merge(tl, tr, root);
            return min(INF, ans);
        }
    };

// st[v]: treap of the weights of edges between v and the neighbours that
// solve_k has already retired. Index 0 receives the dummy parent edge of the
// tree root (prec is started with parent 0, weight 0).
vector<SegT> st;
void prec(int node, int par, int cst) {
    p[node] = array<int, 2>({par, cst});
    tin[node] = timer++;
    roots.insert({tin[node], node});

    for(auto itr = adj[node].begin(); itr != adj[node].end(); ) {
        int c = (*itr)[0];
        if(c == par) {
            itr = adj[node].erase(itr);
            continue;
        }
        prec((int)c, node, (*itr)[1]);
        itr++;
    }
}
// Tree DP for the current bound k over the component of vertices whose degree
// still exceeds k. `node` has to lose need = deg[node] - k incident edges.
// Returns {best[0], best[1]}:
//   best[0] - minimum cost in node's subtree when all `need` removals come
//             from edges to children / retired neighbours;
//   best[1] - same, but one removal is reserved for the edge to the parent
//             (the caller adds that edge's weight itself).
// Edges to retired neighbours are taken cheapest-first from st[node].
array<lli, 2> dfs(int node) {
    vis[node] = k;

    // For every child that is still "big" (degree > k): cost of the cheaper
    // option goes into `base`, and `delta` keeps (closed - kept) so we know the
    // price of forcing that child edge to be closed.
    vector<lli> delta;
    lli base = 0;
    for(const auto &edge : adj[node]) {
        const int child = edge[0];
        if(deg[child] <= k) {
            // adj is ordered by descending degree, so no big child follows.
            break;
        }
        const array<lli,2> sub = dfs(child);
        const lli closed = sub[1] + edge[1];
        const lli kept = sub[0];
        base += min(kept, closed);
        delta.push_back(closed - kept);
    }
    sort(delta.begin(), delta.end());

    // forced[i]: cost of the big children when the i child edges that are
    // cheapest to close are forced closed (a negative delta costs nothing
    // extra, as that edge was already closed in `base`).
    const int big = (int)delta.size();
    vector<lli> forced(big + 1);
    forced[0] = base;
    for(int i = 0; i < big; ++i) {
        forced[i + 1] = forced[i] + max<lli>(delta[i], 0);
    }

    const int need = deg[node] - k;
    const int limit = min(need, big);
    array<lli,2> best = {INF, INF};
    for(int take = 0; take <= limit; ++take) {
        best[0] = min(best[0], forced[take] + st[node].k_sum(need - take));
        if(take < need) {
            best[1] = min(best[1], forced[take] + st[node].k_sum(need - take - 1));
        }
    }
    return best;
}
// Computes the answer for the current global bound k.
// First retires every vertex whose degree equals k: it leaves `roots`, and
// the weight of each of its tree edges is inserted into the treap of the
// vertex at the other end (parent and children alike). Then runs dfs from
// every remaining vertex not yet visited for this k, in preorder, and adds up
// the "parent edge kept" results.
lli solve_k() {
    while(!deg_nodes.empty() && deg_nodes.back()[0] == k) {
        const int v = deg_nodes.back()[1];
        deg_nodes.pop_back();

        roots.erase({tin[v], v});
        st[p[v][0]].insert(p[v][1]);
        for(const auto &e : adj[v]) {
            st[e[0]].insert(e[1]);
        }
    }

    lli total = 0;
    for(const auto &entry : roots) {
        const int v = entry[1];
        if(vis[v] != k) {
            // v was not reached from an ancestor, so it tops its component.
            total += dfs(v)[0];
        }
    }
    return total;
}


std::vector<long long> minimum_closure_costs(int N, std::vector<int> U,
  std::vector<int> V,
  std::vector<int> W) {
    n = N;
    lli sum = 0ll;
    for(int i = 0; i<n - 1; i++) {
      V[i]++; U[i]++;
      adj[U[i]].pb({V[i], W[i]});
      adj[V[i]].pb({U[i], W[i]});
      sum += W[i]; 
    }
    vector<lli> ans(n);
    deg_nodes.resize(n);
    for(int i = 1; i<=n; i++) {
        deg[i] = (int)adj[i].size();
        deg_nodes[i - 1] = {deg[i], i};
    }
    sort(rall(deg_nodes));
    auto comp = [](const array<int,2> &x, const array<int,2> &y) { return deg[x[0]] < deg[y[0]]; }; 
    for(int i = 1; i<=n; i++) {
        sort(rall(adj[i]), comp);
    }
    prec(1, 0, 0);
    ans[0] = sum;
    st.assign(n + 3, SegT());
    

    for(int i = 1; i<=n - 1; i++) {
        // cout<<"i:"<<i<<endl;
        k = i;
        ans[i] = solve_k();

    }
    return ans;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...