# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
---|---|---|---|---|---|---|---|
1184309 |  | anmattroi | Road Closures (APIO21_roads) | C++20 |  | 0 ms | 0 KiB |
#include "roads.h"
#include <bits/stdc++.h>
// Capacity of the global per-vertex / per-edge arrays (vertices and edges are 1-based).
constexpr int maxn = 100005;
using namespace std;
using ii = pair<int, int>;  // (neighbor, edge weight) entries of the adjacency lists
// Segment tree over a coordinate-compressed set of integer values, used as a
// multiset: it tracks how many copies of each value are stored and their sum.
// Compressed positions are 1-based: c[1..n] holds the distinct values in
// increasing order; c[0] is an unused sentinel.
struct segmentTree {
    int n = 0;             // number of distinct values (tree leaves)
    vector<int> c, cnt;    // c: compressed values; cnt[r]: elements stored under node r
    vector<int64_t> sum;   // sum[r]: total of the values stored under node r

    // Rebuilds the value domain from `orz` and inserts every element of `orz`
    // once. Safe to call again on an already-initialised tree.
    void init(const vector<int> &orz) {
        c.assign(1, 0);  // reset, keeping index 0 as the sentinel
        c.insert(c.end(), orz.begin(), orz.end());
        // Sort/dedupe only the real values so the sentinel can never move in
        // among them, whatever the sign of the stored values.
        sort(c.begin() + 1, c.end());
        c.erase(unique(c.begin() + 1, c.end()), c.end());
        n = static_cast<int>(c.size()) - 1;
        cnt.assign(4 * n + 1, 0);
        sum.assign(4 * n + 1, 0);
        for (int value : orz) {
            const int p = static_cast<int>(lower_bound(c.begin() + 1, c.end(), value) - c.begin());
            upd(p, 1);
        }
    }
    // Recomputes node r from its two children.
    void up(int r) {
        cnt[r] = cnt[r<<1] + cnt[r<<1|1];
        sum[r] = sum[r<<1] + sum[r<<1|1];
    }
    void _upd(int u, int d, int r, int lo, int hi) {
        if (lo == hi) {
            cnt[r] += d;
            // Widen before multiplying so a large |d| cannot overflow int.
            sum[r] += static_cast<int64_t>(c[lo]) * d;
            return;
        }
        int mid = (lo + hi) >> 1;
        if (u <= mid) _upd(u, d, r<<1, lo, mid);
        else _upd(u, d, r<<1|1, mid+1, hi);
        up(r);
    }
    int _getCnt(int u, int v, int r, int lo, int hi) {
        if (u <= lo && hi <= v) return cnt[r];
        int mid = (lo + hi) >> 1;
        return (u <= mid ? _getCnt(u, v, r<<1, lo, mid) : 0)
             + (v > mid ? _getCnt(u, v, r<<1|1, mid+1, hi) : 0);
    }
    int64_t _getSum(int u, int v, int r, int lo, int hi) {
        if (u <= lo && hi <= v) return sum[r];
        int mid = (lo + hi) >> 1;
        return (u <= mid ? _getSum(u, v, r<<1, lo, mid) : 0)
             + (v > mid ? _getSum(u, v, r<<1|1, mid+1, hi) : 0);
    }
    // Sum of the k smallest elements under node r; requires 1 <= k <= cnt[r].
    int64_t _bfind(int k, int r, int lo, int hi) {
        if (lo == hi) return 1LL * k * c[lo];
        int mid = (lo + hi) >> 1, trai = cnt[r<<1];
        if (k > trai) return sum[r<<1] + _bfind(k-trai, r<<1|1, mid+1, hi);
        return _bfind(k, r<<1, lo, mid);
    }
    // Adds d copies (negative d removes) of the value at compressed position u,
    // where 1 <= u <= n.
    void upd(int u, int d) {
        _upd(u, d, 1, 1, n);
    }
    // Number of stored elements whose compressed position lies in [u, v]
    // (1 <= u, v <= n). Returns 0 for an empty range or an empty tree.
    int getCnt(int u, int v) {
        if (n == 0 || u > v) return 0;
        return _getCnt(u, v, 1, 1, n);
    }
    // Sum of stored elements whose compressed position lies in [u, v]
    // (1 <= u, v <= n). Returns 0 for an empty range or an empty tree.
    int64_t getSum(int u, int v) {
        if (n == 0 || u > v) return 0;
        return _getSum(u, v, 1, 1, n);
    }
    // Sum of the k smallest stored elements. Requires k <= cnt[1] (the number
    // of stored elements); returns 0 for k <= 0 or an empty tree.
    int64_t bfind(int k) {
        if (k <= 0 || n == 0) return 0;
        return _bfind(k, 1, 1, n);
    }
} st[maxn];
int n;          // number of vertices (set by minimum_closure_costs)
int deg[maxn];  // deg[u]: degree of vertex u in the full tree (vertices are 1-based)
// Tree edges, 1-indexed: edges[1..n-1], 1-based endpoints u, v and weight w.
struct edge {int u, v, w;} edges[maxn];
// par[u] / pe[u]: parent of u and weight of the edge to it, with the tree rooted
// at vertex 1 (par[1] == 0). cr: the degree bound currently being solved.
int pe[maxn], par[maxn], cr;
// Roots of the forest formed by the "heavy" vertices (degree > cr).
set<int> nodes;
// adj[u]: (neighbor, weight) pairs. Holds the whole tree while rooting, and
// afterwards only parent -> heavy-child edges of the heavy forest.
vector<ii> adj[maxn];
// dp[u][0]: minimum total weight of closed edges so that every heavy vertex in
//           u's heavy subtree ends with degree <= cr, where the edge from u to
//           its parent may be kept or closed (closing it costs pe[u]). For the
//           tree root (no parent) this is just the minimum cost.
// dp[u][1]: same, but the edge to the parent is forced closed; pe[u] is included.
int64_t dp[maxn][2];
void pfs(int u, int dad) {
vector<int> orz;
for (auto [v, l] : adj[u])
if (v != dad) {
par[v] = u;
pe[v] = l;
orz.emplace_back(l);
pfs(v, u);
}
st[u].init(orz);
}
int64_t minCost(int u, int need, vector<int> &nho) {
// cout << u << ' ' << need << "\n";
if (need == 0) return 0;
int tr = nho.size();
segmentTree &cur = st[u];
for (int i = 0; i < nho.size(); i++) {
int p = lower_bound(cur.c.begin() + 1, cur.c.end(), nho[i]) - cur.c.begin() - 1;
if ((i+1) + (p == 0 ? 0 : cur.getCnt(1, p)) > need) {
tr = i;
break;
}
}
// cout << tr << ' ' << cur.cnt[1] << "\n";
assert(tr <= need);
int64_t sum = 0;
for (int i = 0; i < tr; i++) sum += nho[i];
return (tr == need ? 0 : cur.bfind(need-tr)) + sum;
}
// Computes dp[u][0] (and dp[u][1] when u has a parent) for the heavy subtree
// rooted at u, recursing first into u's heavy children listed in adj[u].
void dfs(int u) {
    int64_t S = 0;    // cost when no heavy-child edge is closed: sum of dp[v][0]
    vector<int> nho;  // extra cost of closing the edge to each heavy child v
    nho.reserve(adj[u].size());
    for (const auto &child : adj[u]) {
        const int v = child.first;
        dfs(v);
        S += dp[v][0];
        // dp[v][0] <= dp[v][1], so the difference is non-negative; it is at most
        // pe[v] (minCost grows with `need`), so it fits in an int.
        nho.push_back(static_cast<int>(dp[v][1] - dp[v][0]));
    }
    sort(nho.begin(), nho.end());  // minCost expects ascending order
    if (!par[u]) {
        // Tree root: there is no parent edge, so u closes deg[u] - cr child edges.
        dp[u][0] = S + minCost(u, deg[u] - cr, nho);
        return;
    }
    // Parent edge closed: pay pe[u]; one fewer child edge has to be closed.
    dp[u][1] = S + minCost(u, deg[u] - cr - 1, nho) + pe[u];
    dp[u][0] = dp[u][1];
    // Parent edge kept: u must close deg[u] - cr of its deg[u] - 1 child edges,
    // which is only possible when cr > 0. Both operands are int64_t, so std::min
    // deduces a single type (mixing in LLONG_MAX's long long does not compile
    // where int64_t is long).
    if (cr > 0) dp[u][0] = min(dp[u][0], S + minCost(u, deg[u] - cr, nho));
}
// Answers every degree bound k = n-1 .. 0. Vertices become "heavy" (deg > k)
// in decreasing-degree order, and an edge joins the heavy forest once both of
// its endpoints are heavy. For each k, dfs is run from every forest root and
// the roots' dp[.][0] values are summed into ans[k].
vector<long long> solve() {
    // Undirected adjacency of the whole tree, then root it at vertex 1.
    for (int i = 1; i < n; ++i) {
        const edge &e = edges[i];
        adj[e.u].emplace_back(e.v, e.w);
        adj[e.v].emplace_back(e.u, e.w);
    }
    pfs(1, 0);

    // order[1..n]: vertices by decreasing degree (order[0] is an unused placeholder).
    vector<int> order(n + 1, 0);
    iota(order.begin() + 1, order.end(), 1);
    sort(order.begin() + 1, order.end(), [](int x, int y) { return deg[x] > deg[y]; });

    // An edge links two heavy vertices once k drops below its smaller endpoint degree.
    const auto edgeRank = [](const edge &e) { return min(deg[e.u], deg[e.v]); };
    sort(edges + 1, edges + n, [&](const edge &x, const edge &y) { return edgeRank(x) > edgeRank(y); });

    // From here on adj[] holds only parent -> child edges of the heavy forest.
    for (int i = 1; i <= n; ++i) adj[i].clear();

    vector<long long> ans(n, 0);
    int nextVertex = 1;
    int nextEdge = 1;
    for (int k = n - 1; k >= 0; --k) {
        cr = k;
        // Promote vertices whose degree now exceeds k; each one's parent edge
        // leaves the parent's light-edge structure.
        for (; nextVertex <= n && deg[order[nextVertex]] > k; ++nextVertex) {
            const int u = order[nextVertex];
            nodes.insert(u);
            if (par[u] != 0) {
                segmentTree &parentTree = st[par[u]];
                const int pos = lower_bound(parentTree.c.begin() + 1, parentTree.c.end(), pe[u]) - parentTree.c.begin();
                parentTree.upd(pos, -1);
            }
        }
        // Link newly heavy-heavy edges as parent -> child; the child stops being a root.
        for (; nextEdge < n && edgeRank(edges[nextEdge]) > k; ++nextEdge) {
            int top = edges[nextEdge].u;
            int bottom = edges[nextEdge].v;
            if (par[top] == bottom) swap(top, bottom);
            adj[top].emplace_back(bottom, edges[nextEdge].w);
            nodes.erase(bottom);
        }
        for (int root : nodes) dfs(root);
        for (int root : nodes) ans[k] += dp[root][0];
    }
    return ans;
}
// Entry point: takes N vertices and N-1 edges (U[i], V[i], W[i]) with 0-based
// endpoints, stores them 1-based in edges[1..N-1], tallies vertex degrees,
// and returns solve()'s answer for every degree bound 0..N-1.
vector<long long> minimum_closure_costs(int N, vector<int> U, vector<int> V, vector<int> W) {
    n = N;
    for (int i = 1; i < n; ++i) {
        const int a = U[i - 1] + 1;
        const int b = V[i - 1] + 1;
        edges[i] = edge{a, b, W[i - 1]};
        ++deg[a];
        ++deg[b];
    }
    return solve();
}
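/*
Sample input (presumably for local testing): N, then N-1 lines "U V W" giving
0-based edge endpoints and the edge weight, as passed to minimum_closure_costs.
5
0 1 1
0 2 4
0 3 3
2 4 2
*/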