#include "roads.h"
#include <bits/stdc++.h>
using namespace std;
// Shorthand for the 64-bit signed integer type used throughout this file.
#define ll long long
// Size of the fixed global arrays below.
const int nx=1e5+5;
// Fixed-seed generator; treap::node draws its heap priority from it.
mt19937 rng(12345678);
// n        - vertex count, set from N in minimum_closure_costs
// dp[u][0] - sm[u] plus the sum of the k largest entries of t[u] (see update)
// dp[u][1] - sm[u] plus the k-1 largest entries of t[u] plus pw[u],
//            raised to dp[u][0] when that is larger (see update)
// sm[u]    - running sum of dp[c][0] contributed by u's child components
// tot      - sum of all edge weights
// dsu[v]   - union-find parent; archive() links every child to the archived vertex
// deg[v]   - number of edges incident to v
// in[v]    - visit index assigned by predfs (1-based, increasing in visit order)
// cnt      - counter behind in[]
// pa[v]    - parent of v in the tree rooted by predfs (the root is its own parent)
// pw[v]    - weight of the edge between v and pa[v]
ll n, dp[nx][2], sm[nx], tot, dsu[nx], deg[nx], in[nx], cnt, pa[nx], pw[nx];
// d[u]: adjacency list of (neighbour, weight) pairs.
vector<pair<ll, ll>> d[nx];
// arc[k]: vertices whose degree is exactly k; they are archived after round k.
vector<ll> arc[nx];
// Vertices that have not been archived yet.
set<ll> s;
// Returns the representative of x in dsu[] and points every vertex on the
// walked path directly at that representative.
int find(int x)
{
    // First pass: walk up to the representative.
    int root=x;
    while (dsu[root]!=root) root=dsu[root];
    // Second pass: compress the path just walked.
    while (dsu[x]!=root)
    {
        int up=dsu[x];
        dsu[x]=root;
        x=up;
    }
    return root;
}
// Treap ordered by stored value (ascending in-order) that also tracks subtree
// sizes and subtree sums, so the sum of the k largest stored values can be
// read off after a single split.  Nodes are heap-allocated; erase() frees the
// node it removes, the remaining nodes are never freed (there is no destructor).
struct treap
{
    struct node
    {
        ll sm, vl, key, sz; // subtree sum, stored value, heap priority, subtree size
        node *l, *r;
        node(ll vl): sm(vl), vl(vl), key(rng()), sz(1), l(nullptr), r(nullptr){}
    };
    typedef node* pnode;
    // Root of the tree.  Initialised explicitly so that a treap that does not
    // live in static storage still starts out empty.
    pnode rt=nullptr;
    // Size / sum of a possibly empty subtree.
    ll getsz(pnode x) {return x?x->sz:0;}
    ll getsm(pnode x) {return x?x->sm:0;}
    // Recomputes the cached size and sum of x from its children.
    void update(pnode x)
    {
        if (!x) return;
        x->sz=1+getsz(x->l)+getsz(x->r);
        x->sm=x->vl+getsm(x->l)+getsm(x->r);
    }
    // Joins l and r into k; every node of l precedes every node of r in order.
    // The node with the larger priority becomes the root (ties favour l).
    void merge(pnode l, pnode r, pnode &k)
    {
        if (!l||!r) return k=l?l:r, void();
        if (l->key>=r->key) merge(l->r, r, l->r), k=l;
        else merge(l, r->l, r->l), k=r;
        update(k);
    }
    // Splits k by count from the right: r receives the `key` last (largest)
    // nodes, l receives the rest.  key<=0 leaves everything in l; a key of at
    // least the size of k moves everything to r.
    void split(pnode &l, pnode &r, pnode k, ll key)
    {
        if (!k) return l=r=nullptr, void();
        // k and its whole right subtree fit into the requested suffix.
        if (1+getsz(k->r)<=key) split(l, k->l, k->l, key-(1+getsz(k->r))), r=k;
        else split(k->r, r, k->r, key), l=k;
        update(l), update(r);
    }
    // Splits k by value: l receives values < vl, r receives values >= vl.
    void splitlowerbound(pnode &l, pnode &r, pnode k, ll vl)
    {
        if (!k) return l=r=nullptr, void();
        if (k->vl>=vl) splitlowerbound(l, k->l, k->l, vl), r=k;
        else splitlowerbound(k->r, r, k->r, vl), l=k;
        update(l), update(r);
    }
    // Inserts one node holding vl, placed before any existing equal values.
    void insert(ll vl)
    {
        pnode p1, p2, p3=new node(vl);
        splitlowerbound(p1, p2, rt, vl);
        merge(p1, p3, rt);
        merge(rt, p2, rt);
    }
    // Removes one node: the smallest stored value that is >= vl.  Does nothing
    // when no such value exists.
    void erase(ll vl)
    {
        pnode p1, p2, p3;
        splitlowerbound(p1, p2, rt, vl);
        // Peel the first node off the ">= vl" part; p2 keeps just that node.
        split(p2, p3, p2, getsz(p2)-1);
        merge(p1, p3, rt);
        // The detached node is no longer reachable from rt, so release it
        // (deleting a null pointer is a no-op).
        delete p2;
    }
    // Returns the sum of the k largest stored values (everything when k is at
    // least the size, 0 when k<=0).  The tree is reassembled before returning.
    ll query(ll k)
    {
        pnode p1, p2;
        split(p1, p2, rt, k);
        auto tmp=getsm(p2);
        merge(p1, p2, rt);
        return tmp;
    }
    // Debug helper: prints the values of subtree x in ascending order.
    void show(pnode x)
    {
        if (!x) return;
        show(x->l);
        cout<<x->vl<<' ';
        show(x->r);
    }
} t[nx];
// Roots the tree: records parent, visit index and parent-edge weight for the
// subtree of u, where p is u's parent (p==u for the root).
void predfs(int u, int p)
{
    pa[u]=p;
    in[u]=++cnt;
    // Every non-root vertex starts with a zero entry in its parent's treap.
    if (p!=u) t[p].insert(0);
    for (const auto &edge:d[u])
    {
        const auto nxt=edge.first;
        if (nxt==p) continue;
        pw[nxt]=edge.second;
        predfs(nxt, u);
    }
}
// Recomputes dp[u] for limit k and pushes the change to the representative
// find(u) and, when that representative has a parent, to the parent's sm/treap.
void update(int u, int k)
{
    const int head=find(u);
    const bool hasParent=pa[head]!=head;
    const bool absorbed=head!=u; // u's value is folded directly into dp[head]

    // Take the representative's current contribution out of its parent.
    if (hasParent)
    {
        sm[pa[head]]-=dp[head][0];
        t[pa[head]].erase(dp[head][1]-dp[head][0]);
    }
    // Remove u's previous best value from the representative.
    if (absorbed)
    {
        const auto stale=max(dp[u][0], dp[u][1]);
        dp[head][0]-=stale;
        dp[head][1]-=stale;
    }

    // dp[u][0]: take the k largest entries of t[u];
    // dp[u][1]: take k-1 entries plus the parent edge, never below dp[u][0].
    dp[u][0]=sm[u]+t[u].query(k);
    dp[u][1]=max(sm[u]+t[u].query(k-1)+pw[u], dp[u][0]);

    // Fold u's new best value back into the representative.
    if (absorbed)
    {
        const auto fresh=max(dp[u][0], dp[u][1]);
        dp[head][0]+=fresh;
        dp[head][1]+=fresh;
    }
    // Put the representative's refreshed contribution back into its parent.
    if (hasParent)
    {
        sm[pa[head]]+=dp[head][0];
        t[pa[head]].insert(dp[head][1]-dp[head][0]);
    }
}
// Retires u: drops it from the active set and makes u the dsu parent of each
// of its children.
void archive(int u)
{
    s.erase(u);
    const auto parent=pa[u];
    for (const auto &edge:d[u])
    {
        if (edge.first==parent) continue;
        dsu[edge.first]=u;
    }
}
std::vector<long long> minimum_closure_costs(int N, std::vector<int> U,
                                             std::vector<int> V,
                                             std::vector<int> W) {
    n=N;
    vector<ll> mx(n);
    for (int i=0; i<n-1; i++) d[U[i]].push_back({V[i], W[i]}), d[V[i]].push_back({U[i], W[i]}), tot+=W[i], deg[U[i]]++, deg[V[i]]++;
    for (int i=0; i<n; i++) s.insert(i), arc[deg[i]].push_back(i), dsu[i]=i;
    predfs(0, 0);
    // minimum closure costs = total - maximum selected road
    for (int i=1; i<n; i++) 
    {
        vector<pair<ll, ll>> ord;
        for (auto u:s) ord.push_back({-in[u], u});
        sort(ord.begin(), ord.end());
        for (auto [_, u]:ord) update(u, i);
        mx[i]=dp[0][0];
        for (auto u:arc[i]) archive(u);
    }
    for (int i=0; i<n; i++) mx[i]=tot-mx[i];
    return mx;
}
| # | Verdict | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict | Execution time | Memory | Grader output | 
|---|
| Fetching results... |