#include "roads.h"
#include <bits/stdc++.h>
using namespace std;
// 64-bit integer type used for all weights and DP values.
using ll = long long;
// Array bound: supports vertex ids 0 .. nx-1.
const int nx = 100005;
// Fixed seed, so treap priorities (and therefore treap shapes) are reproducible.
std::mt19937 rng(12345678);

ll n;                                  // number of vertices
ll dp[nx][2];                          // dp[u][0]: value with u's parent edge unused; dp[u][1]: best with it allowed (>= dp[u][0])
ll sm[nx];                             // sm[u]: sum of dp[c][0] over children c of u
ll tot;                                // sum of all edge weights
ll dsu[nx];                            // union-find parent; a child is linked to its parent once the parent is archived
ll deg[nx];                            // vertex degree
ll in[nx];                             // DFS pre-order index
ll cnt;                                // pre-order counter
ll pa[nx];                             // DFS parent (the root is its own parent)
ll pw[nx];                             // weight of the edge to the DFS parent
std::vector<std::pair<ll, ll>> d[nx];  // adjacency list: (neighbour, weight)
std::vector<ll> arc[nx];               // arc[x]: vertices whose degree is x
std::set<ll> s;                        // vertices that have not been archived yet
// Returns the DSU representative of x. Afterwards every vertex on the path
// from x points directly at the representative (path compression).
int find(int x)
{
    int root = x;
    while (dsu[root] != root) root = dsu[root];
    // Second pass: re-point every vertex on the path at the root.
    while (dsu[x] != root)
    {
        const int next = dsu[x];
        dsu[x] = root;
        x = next;
    }
    return root;
}
// Randomized treap: a BST by value (ascending in-order) and a max-heap by
// random priority, augmented with subtree size and subtree sum. It acts as a
// multiset of values and answers "sum of the k largest values".
// The file keeps one instance per vertex (t[u]) holding dp[c][1]-dp[c][0]
// for each child c of u.
struct treap
{
    struct node
    {
        ll sm, vl, key, sz;   // subtree sum, own value, heap priority, subtree size
        node *l, *r;          // l: smaller values, r: larger values
        explicit node(ll vl): sm(vl), vl(vl), key(rng()), sz(1), l(nullptr), r(nullptr){}
    };
    typedef node* pnode;
    pnode rt = nullptr;       // explicit, so instances outside static storage also start empty

    ll getsz(pnode x) const {return x?x->sz:0;}
    ll getsm(pnode x) const {return x?x->sm:0;}

    // Recomputes x's size and sum from its children (no-op on nullptr).
    void update(pnode x)
    {
        if (!x) return;
        x->sz=1+getsz(x->l)+getsz(x->r);
        x->sm=x->vl+getsm(x->l)+getsm(x->r);
    }

    // Joins l and r into k. Every value in l must be <= every value in r.
    void merge(pnode l, pnode r, pnode &k)
    {
        if (!l||!r) return k=l?l:r, void();
        if (l->key>=r->key) merge(l->r, r, l->r), k=l;
        else merge(l, r->l, r->l), k=r;
        update(k);
    }

    // Splits k by rank. r receives the `key` largest elements (all of them if
    // key >= size, none if key <= 0); l receives the rest.
    void split(pnode &l, pnode &r, pnode k, ll key)
    {
        if (!k) return l=r=nullptr, void();
        if (1+getsz(k->r)<=key) split(l, k->l, k->l, key-(1+getsz(k->r))), r=k;
        else split(k->r, r, k->r, key), l=k;
        update(l), update(r);
    }

    // Splits k by value: l receives values < vl, r receives values >= vl.
    void splitlowerbound(pnode &l, pnode &r, pnode k, ll vl)
    {
        if (!k) return l=r=nullptr, void();
        if (k->vl>=vl) splitlowerbound(l, k->l, k->l, vl), r=k;
        else splitlowerbound(k->r, r, k->r, vl), l=k;
        update(l), update(r);
    }

    // Adds one copy of vl.
    void insert(ll vl)
    {
        pnode p1, p2, p3=new node(vl);
        splitlowerbound(p1, p2, rt, vl);
        merge(p1, p3, rt);
        merge(rt, p2, rt);
    }

    // Removes one copy of vl and frees its node.
    // If vl is not present, the treap is left unchanged.
    void erase(ll vl)
    {
        pnode lo, hi, rest;
        splitlowerbound(lo, hi, rt, vl);    // lo: < vl, hi: >= vl
        split(hi, rest, hi, getsz(hi)-1);   // hi: single smallest value >= vl, rest: the others
        if (hi && hi->vl==vl) delete hi;    // found: release the detached node
        else merge(hi, rest, rest);         // absent: put it back
        merge(lo, rest, rt);
    }

    // Returns the sum of the k largest values (0 if k <= 0, the total if k >= size).
    ll query(ll k)
    {
        pnode p1, p2;
        split(p1, p2, rt, k);
        const ll res=getsm(p2);
        merge(p1, p2, rt);
        return res;
    }

    // Debug helper: prints the values in x's subtree in ascending order.
    void show(pnode x)
    {
        if (!x) return;
        show(x->l);
        std::cout<<x->vl<<' ';
        show(x->r);
    }
} t[nx];
// Pre-order DFS from u, whose parent is p (pass p == u for the root).
// For each visited vertex x it records:
//   - pa[x], its parent;
//   - in[x], its pre-order index;
//   - pw[c], the weight of each child edge;
//   - a 0 inserted into the parent's treap, matching the all-zero initial dp values.
// Iterative, so deep trees do not exhaust the call stack.
void predfs(int u, int p)
{
    // Children are pushed in reverse adjacency order so they pop in adjacency
    // order. This reproduces the recursive visiting order exactly.
    std::vector<std::pair<int, int>> pending{{u, p}};
    while (!pending.empty())
    {
        const auto [x, parent] = pending.back();
        pending.pop_back();
        pa[x] = parent;
        in[x] = ++cnt;
        if (x != pa[x]) t[pa[x]].insert(0);
        for (auto it = d[x].rbegin(); it != d[x].rend(); ++it)
        {
            const int child = static_cast<int>(it->first);
            if (child == parent) continue;
            pw[child] = it->second;
            pending.emplace_back(child, x);
        }
    }
}
void update(int u, int k)
{
auto hd=find(u);
if (hd!=u)
{
if (pa[hd]!=hd) sm[pa[hd]]-=dp[hd][0], t[pa[hd]].erase(dp[hd][1]-dp[hd][0]);
dp[hd][0]-=max(dp[u][0], dp[u][1]);
dp[hd][1]-=max(dp[u][0], dp[u][1]);
dp[u][0]=sm[u]+t[u].query(k);
dp[u][1]=sm[u]+t[u].query(k-1)+pw[u];
dp[u][1]=max(dp[u][1], dp[u][0]);
dp[hd][0]+=max(dp[u][0], dp[u][1]);
dp[hd][1]+=max(dp[u][0], dp[u][1]);
if (pa[hd]!=hd) sm[pa[hd]]+=dp[hd][0], t[pa[hd]].insert(dp[hd][1]-dp[hd][0]);
}
else
{
if (pa[hd]!=hd) sm[pa[hd]]-=dp[hd][0], t[pa[hd]].erase(dp[hd][1]-dp[hd][0]);
dp[u][0]=sm[u]+t[u].query(k);
dp[u][1]=sm[u]+t[u].query(k-1)+pw[u];
dp[u][1]=max(dp[u][1], dp[u][0]);
if (pa[hd]!=hd) sm[pa[hd]]+=dp[hd][0], t[pa[hd]].insert(dp[hd][1]-dp[hd][0]);
}
//cout<<"update "<<u<<' '<<dp[u][0]<<' '<<dp[u][1]<<'\n';
}
void archive(int u)
{
s.erase(u);
for (auto [v, w]:d[u]) if (v!=pa[u]) dsu[v]=u;
}
std::vector<long long> minimum_closure_costs(int N, std::vector<int> U,
std::vector<int> V,
std::vector<int> W) {
n=N;
vector<ll> mx(n);
for (int i=0; i<n-1; i++) d[U[i]].push_back({V[i], W[i]}), d[V[i]].push_back({U[i], W[i]}), tot+=W[i], deg[U[i]]++, deg[V[i]]++;
for (int i=0; i<n; i++) s.insert(i), arc[deg[i]].push_back(i), dsu[i]=i;
predfs(0, 0);
// minimum closure costs = total - maximum selected road
for (int i=1; i<n; i++)
{
vector<pair<ll, ll>> ord;
for (auto u:s) ord.push_back({-in[u], u});
sort(ord.begin(), ord.end());
for (auto [_, u]:ord) update(u, i);
mx[i]=dp[0][0];
for (auto u:arc[i]) archive(u);
}
for (int i=0; i<n; i++) mx[i]=tot-mx[i];
return mx;
}
/*
 * Judge result table, apparently pasted from an online-judge submission page.
 * It is not part of the program; it is kept inside this comment so the file compiles.
 *
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/