#include "roads.h"
#define pb push_back
#define mp make_pair
#include <vector>
#include <set>
#include <algorithm>
using namespace std;
using ll = long long;
using pi = pair<int, ll>;  // (neighbour, road weight)
// Cost used for "impossible" states; also the weight passed for the root's nonexistent parent road.
const ll INF = 1000000000000000LL;  // 1e15
int N;                        // number of vertices
vector<vector<pi>> adj_list;  // adjacency lists; after dfs each one lists only the children
vector<int> subtree;          // subtree sizes
vector<int> deg;              // degree in the original (unrooted) tree, parent road included
// dp1[node][k]: minimum cost inside node's subtree (plus the road to the parent if it is closed)
// so that every vertex there keeps at most k open roads; the parent road may be open or closed,
// whichever is cheaper. Entries past the end of the vector are implicitly 0.
vector<vector<ll>> dp1;
// dp2[node][k]: same, but the road to the parent is closed and its weight IS included.
// Stored only for k <= deg[node]; beyond that dp2 - dp1 is just that road's weight.
vector<vector<ll>> dp2;
// best1/best2[node][d]: extra cost of closing child roads so that node has at most d open roads,
// with the parent road open (best1) or closed (best2). Filled by calculate_best_cuts.
vector<vector<ll>> best1, best2;
// Element-wise add two cost vectors, leaving the result in `a`; positions missing from the
// shorter vector count as 0. If `b` is the longer one the two are swapped first, so afterwards
// `b` holds `a`'s original contents (otherwise `b` is left unchanged).
void merge(std::vector<long long>& a, std::vector<long long>& b) {
	if (b.size() > a.size()) {
		a.swap(b);
	}
	const std::size_t shorter = b.size();
	for (std::size_t idx = 0; idx < shorter; ++idx) {
		a[idx] += b[idx];
	}
}
// Ordering for (neighbour, weight) pairs: the neighbour with the smaller entry in the global
// deg[] comes first.
bool degcmp(pi a, pi b) {
	const int lhs_degree = deg[a.first];
	const int rhs_degree = deg[b.first];
	return lhs_degree < rhs_degree;
}
// Remove one copy of `v` from `s`, if present, and keep the running total `t` in sync.
// Does nothing when `v` is absent.
inline void erase(std::multiset<long long>& s, long long& t, long long v) {
	const auto hit = s.find(v);
	if (hit == s.end()) {
		return;
	}
	s.erase(hit);
	t -= v;
}
// Add `v` to `s` (updating the running total `t`). If `s` then holds more than `top` elements,
// its single largest element is dropped, so repeated calls keep only the `top` smallest values.
// `top` is compared as an unsigned size, exactly as `top < s.size()` would.
inline void insert(std::multiset<long long>& s, long long& t, long long v, int top) {
	s.insert(v);
	t += v;
	if (s.size() <= static_cast<std::size_t>(top)) {
		return;
	}
	const auto largest = std::prev(s.end());
	t -= *largest;
	s.erase(largest);
}
// Calculate costs coming from children (both for d and d-1 children taken)
void calculate_best_cuts(int node) {
	multiset<ll> opt;
	ll sum = 0;
	best1[node].resize(deg[node] + 1);
	best2[node].resize(deg[node] + 1);
	sort(adj_list[node].begin(), adj_list[node].end(), degcmp);
	int start = 0;
	for (int d = 0; d <= deg[node]; d++) {
		int p = start;
		while (p < adj_list[node].size()) {
			int u;
			ll w;
			tie(u, w) = adj_list[node][p];
			if (d > 0) {
				erase(opt, sum, dp2[u][d - 1] - dp1[u][d - 1]);
			}
			p++;
		}
		p = start;
		for (int i = 0; i < start; i++) {
			erase(opt, sum, adj_list[node][i].second);
		}
		for (int i = 0; i < start; i++) {
			insert(opt, sum, adj_list[node][i].second, deg[node] - d);
		}
		while (p < adj_list[node].size()) {
			int u;
			ll w;
			tie(u, w) = adj_list[node][p];
			// Added cost of closing road from u to node
			insert(opt, sum, dp2[u][d] - dp1[u][d], deg[node] - d);
			if (deg[u] == d) {
				start = p + 1;
			}
			p++;
		}
		// Should have deg[node] - d elements (if at root, we have deg[node] - d + 1 elements)
		// If we are to open parent, we should look at closing deg[node] - d kids
		// If we are to close parent, we should look at closing deg[node] - d - 1 kids
		while (opt.size() > deg[node] - d) {
			erase(opt, sum, *prev(opt.end()));
		}
		if (d == 0) {
			if (node == 0) {
				best1[node][0] = sum;
			}
			else {
				best1[node][0] = INF;
			}
			best2[node][0] = sum;
		}
		else {
			best1[node][d] = sum;
			if (d != deg[node]) {
				best2[node][d] = sum - *prev(opt.end());
			}
			else {
				best2[node][d] = 0;
			}
		}
	}
}
void dfs(int node, int par, ll w) {
	deg[node] = adj_list[node].size();
	if (par != -1) {
		adj_list[node].erase(find(adj_list[node].begin(), adj_list[node].end(), mp(par, w)));
	}
		
	subtree[node] = 1;
	for (auto [neigh, w] : adj_list[node]) {
		dfs(neigh, node, w);
		subtree[node] += subtree[neigh];
	}
	calculate_best_cuts(node);
	vector<ll> dp;
	for (auto [neigh, w] : adj_list[node]) {
		merge(dp1[node], dp1[neigh]);
	}
	dp2[node].resize(deg[node] + 1);
	if (dp1[node].size() < deg[node] + 1) {
		dp1[node].resize(deg[node] + 1);
	}
	for (int d = 0; d <= deg[node]; d++) {
		dp2[node][d] = dp1[node][d] + best2[node][d] + w;
		dp1[node][d] = min(dp1[node][d] + best1[node][d], dp2[node][d]);
	}
}
vector<ll> minimum_closure_costs(int n, vector<int> U, vector<int> V,vector<int> W) {
	N = n;
	adj_list.resize(N);
	dp1.resize(N);
	dp2.resize(N);
	subtree.resize(N);
	deg.resize(N);
	best1.resize(N);
	best2.resize(N);
	for (int i = 0; i < N - 1; i++) {
		adj_list[U[i]].pb(mp(V[i], W[i]));
		adj_list[V[i]].pb(mp(U[i], W[i]));
	}
	dfs(0, -1, INF);
	while (dp1[0].size() < N) {
		dp1[0].pb(0);
	}
	return dp1[0];
}
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... |