This submission was migrated from the previous version of oj.uz, which used a different machine for grading. It may produce a different result if resubmitted.
#include "roads.h"
#include <bits/stdc++.h>
using namespace std;
namespace {

using i64 = long long;

// Sentinel for "no feasible set of closures". For k >= 1 every subproblem is
// feasible, so this never reaches an answer.
constexpr i64 kInf = 1000000000000000000LL;

// One adjacency entry: the neighbouring vertex and the index of the edge.
struct Arc {
  int to;
  int edge;
};

// Weights of one vertex's edges to currently *light* neighbours (neighbours
// whose degree is <= the current k). Every incident edge owns a fixed 1-based
// slot, and slots are ordered by edge weight. A Fenwick tree over the slots
// therefore answers "sum of the x smallest active weights" in O(log deg).
class LightEdges {
 public:
  // Prepares `slots` empty slots.
  void init(int slots) {
    n_ = slots;
    cnt_.assign(slots + 1, 0);
    sum_.assign(slots + 1, 0);
    top_ = 1;
    while (top_ * 2 <= n_) top_ *= 2;  // highest power of two <= n_
    size_ = 0;
  }

  // Activates the edge in 1-based `slot` with weight `w`. Each slot must be
  // activated at most once.
  void activate(int slot, i64 w) {
    ++size_;
    for (; slot <= n_; slot += slot & -slot) {
      ++cnt_[slot];
      sum_[slot] += w;
    }
  }

  // Number of activated edges.
  int size() const { return size_; }

  // Sum of the `x` smallest activated weights; requires 0 <= x <= size().
  // Fenwick descent: find the longest slot prefix holding exactly x edges.
  i64 smallest_sum(int x) const {
    int pos = 0;
    i64 acc = 0;
    for (int step = top_; step > 0; step >>= 1) {
      const int next = pos + step;
      if (next <= n_ && cnt_[next] <= x) {
        pos = next;
        x -= cnt_[next];
        acc += sum_[next];
      }
    }
    return acc;
  }

 private:
  int n_ = 0;
  int top_ = 1;
  int size_ = 0;
  std::vector<int> cnt_;
  std::vector<i64> sum_;
};

// Cheapest extra cost of closing at least `need` child edges of one vertex.
//   deltas: sorted costs (cut_child + w - keep_child) of closing each edge to
//           a heavy child; these may be negative.
//   light:  weights of the edges to light neighbours, each closable at cost w.
// For each j = number of heavy-child edges closed, the j smallest deltas are
// taken and the rest is filled with the smallest light weights. This assumes
// light weights are non-negative (the task's W[i] >= 1 — confirm).
// Returns kInf when `need` exceeds the number of closable edges.
i64 cheapest_closure(const std::vector<i64>& deltas, const LightEdges& light,
                     int need) {
  i64 best = kInf;
  i64 prefix = 0;
  for (int j = 0; j <= static_cast<int>(deltas.size()); ++j) {
    if (j > 0) prefix += deltas[j - 1];
    const int from_light = std::max(0, need - j);
    if (from_light <= light.size())
      best = std::min(best, prefix + light.smallest_sum(from_light));
  }
  return best;
}

}  // namespace

// For every k in [0, N-1], returns the minimum total weight of edges to remove
// from the tree (U[i], V[i], W[i]) so that every vertex has degree <= k.
//
// A vertex is "heavy" for a given k when deg > k. Only heavy vertices impose
// constraints.
//   * An edge to a light neighbour is an optional closure at cost w. These
//     weights live in per-vertex Fenwick trees that only grow as k increases.
//   * Heavy vertices form a forest, solved with a tree DP per component:
//       keep[u] = min cost in u's subtree when the edge to u's parent stays;
//       cut[u]  = the same when that edge is closed (its weight is paid by
//                 the parent).
// Over all k, the heavy vertices and the heavy-heavy edges visited sum to
// O(N), so the total running time is O(N log N). The DFS is iterative, so long
// heavy paths cannot overflow the stack. No global state is kept, so the
// function can be called repeatedly.
std::vector<long long> minimum_closure_costs(int N, std::vector<int> U,
                                             std::vector<int> V,
                                             std::vector<int> W) {
  const int n = N;
  if (n <= 0) return {};
  const int m = static_cast<int>(U.size());  // N - 1 for a tree

  std::vector<std::vector<Arc>> adj(n);
  for (int e = 0; e < m; ++e) {
    adj[U[e]].push_back({V[e], e});
    adj[V[e]].push_back({U[e], e});
  }
  std::vector<int> deg(n);
  for (int v = 0; v < n; ++v) deg[v] = static_cast<int>(adj[v].size());

  // slot[e][0] / slot[e][1]: Fenwick slot of edge e at endpoint U[e] / V[e].
  std::vector<std::array<int, 2>> slot(m);
  std::vector<LightEdges> light(n);
  std::vector<int> incident;
  for (int v = 0; v < n; ++v) {
    incident.clear();
    for (const Arc& a : adj[v]) incident.push_back(a.edge);
    std::sort(incident.begin(), incident.end(), [&](int a, int b) {
      return W[a] != W[b] ? W[a] < W[b] : a < b;
    });
    for (int i = 0; i < static_cast<int>(incident.size()); ++i) {
      const int e = incident[i];
      slot[e][U[e] == v ? 0 : 1] = i + 1;
    }
    light[v].init(deg[v]);
    // Heavy neighbours first, so scans for a given k stop at the first light one.
    std::sort(adj[v].begin(), adj[v].end(), [&](const Arc& a, const Arc& b) {
      return deg[a.to] > deg[b.to];
    });
  }

  std::vector<std::vector<int>> by_degree(n);
  for (int v = 0; v < n; ++v) by_degree[deg[v]].push_back(v);
  std::vector<int> order(n);
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](int a, int b) { return deg[a] > deg[b]; });

  std::vector<long long> ret(n, 0);
  for (int e = 0; e < m; ++e) ret[0] += W[e];  // k = 0: close every edge

  std::vector<int> stamp(n, 0);  // stamp[v] == k  <=>  v visited for this k
  std::vector<int> parent(n, -1);
  std::vector<int> stack;
  std::vector<int> preorder;
  std::vector<i64> keep(n, 0);
  std::vector<i64> cut(n, 0);
  std::vector<i64> deltas;

  for (int k = 1; k < n; ++k) {
    // Vertices of degree k become light: each of their edges is now an
    // optional closure for the neighbour on the other side.
    for (int u : by_degree[k])
      for (const Arc& a : adj[u])
        light[a.to].activate(slot[a.edge][U[a.edge] == a.to ? 0 : 1],
                             W[a.edge]);

    for (int root : order) {
      if (deg[root] <= k) break;  // every remaining vertex is light
      if (stamp[root] == k) continue;

      // Preorder of root's heavy component (parents precede children).
      stamp[root] = k;
      parent[root] = -1;
      stack.assign(1, root);
      preorder.clear();
      while (!stack.empty()) {
        const int u = stack.back();
        stack.pop_back();
        preorder.push_back(u);
        for (const Arc& a : adj[u]) {
          if (deg[a.to] <= k) break;
          if (a.to == parent[u]) continue;
          parent[a.to] = u;
          stamp[a.to] = k;
          stack.push_back(a.to);
        }
      }

      // Children before parents: combine the subtree DP values.
      for (auto it = preorder.rbegin(); it != preorder.rend(); ++it) {
        const int u = *it;
        i64 base = 0;  // cost when every heavy-child edge stays open
        deltas.clear();
        for (const Arc& a : adj[u]) {
          if (deg[a.to] <= k) break;
          if (a.to == parent[u]) continue;
          base += keep[a.to];
          deltas.push_back(cut[a.to] + W[a.edge] - keep[a.to]);
        }
        std::sort(deltas.begin(), deltas.end());
        const int excess = deg[u] - k;  // closures needed if parent edge stays
        keep[u] = base + cheapest_closure(deltas, light[u], excess);
        cut[u] = base + cheapest_closure(deltas, light[u], excess - 1);
      }
      ret[k] += keep[root];
    }
  }
  return ret;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |