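/*
 * 제출 #1335761 (Submission #1335761)
 *
 * # | 제출 시각 (submitted at) | 아이디 (user) | 문제 (problem) | 언어 (language) | 결과 (result) | 실행 시간 (time) | 메모리 (memory)
 * 1335761 | (not shown) | danhdanh28032000 | 도로 폐쇄 — Road Closures (APIO21_roads) | C++20 | 100 / 100 | 762 ms | 27192 KiB
 */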
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Array bound; `mod` and `inf` are declared but not referenced by this solution.
constexpr int N = 1e5 + 3;
constexpr int mod = 998244353;
constexpr int inf = 2e9;

// dp[u][0] / dp[u][1]: best kept weight in u's subtree (see dfs / DFS).
ll dp[N][2];
// Disjoint-set array used by find/unite: parent index, or -(set size) for a root.
ll lab[N];
// sum: total weight of all edges; TONG: total weight of light-light edges.
ll sum;
ll TONG;
// Adjacency lists of (neighbour, edge index).
vector <pair<int, int>> g[N];
// Per contracted node: sorted weights of edges to degree-1 neighbours, and their prefix sums.
vector <ll> value[N];
vector <ll> pf[N];
int maxdeg;
int base = 500;   // degree threshold separating the two phases
int root = -1;
int sz[N];        // maximum degree found inside a rooted subtree (prepc / PREPC)
bool leaf[N];

// Returns the representative of u's disjoint set and re-points every node on the
// walked path directly at that representative (full path compression).
int find(int u) {
	int rep = u;
	while (lab[rep] >= 0) rep = static_cast<int>(lab[rep]);
	while (u != rep) {
		const int next = static_cast<int>(lab[u]);
		lab[u] = rep;
		u = next;
	}
	return rep;
}

// Union by size of two set representatives r and s (lab[rep] == -(set size)).
// The representative of the larger set (r on ties) becomes the parent of the other.
void unite(int r, int s) {
	const int parent = (lab[r] > lab[s]) ? s : r;
	const int child = (parent == r) ? s : r;
	lab[parent] += lab[child];
	lab[child] = parent;
}

// With u entered from parent p, stores in sz[u] the largest g[x].size() over every
// vertex x of u's subtree. dfs compares it with the current cap to skip subtrees.
void prepc(int u, int p) {
	int best = static_cast<int>(g[u].size());
	for (const auto &edge : g[u]) {
		const int child = edge.first;
		if (child != p) {
			prepc(child, u);
			best = max(best, sz[child]);
		}
	}
	sz[u] = best;
}

void dfs(int u, int p, int &curK, vector <int> &w) {
	dp[u][0] = dp[u][1] = 0;
	vector <int> luu;	
	for (pair <int, int> P : g[u]) {
		int v = P.first;
		if (v == p) continue;
		if (curK <= sz[v]) dfs(v, u, curK, w);
		int b = dp[v][1] + w[P.second] - dp[v][0];
		if (b > 0) luu.push_back(b);
		dp[u][1] += dp[v][0];
		dp[u][0] += dp[v][0];
	}
	int sz = luu.size();
	if (sz < curK) {
		ll add = 0;
		for (int vl : luu) add += vl;
		dp[u][0] += add;
		dp[u][1] += add;
	}
	else if (sz != 0) {
		nth_element(luu.begin(), luu.end() - curK, luu.end());
		ll add = 0;
		int MIN = luu.back();
		for (int i = sz - 1; i >= sz - curK; --i) add += luu[i], MIN = min(MIN, luu[i]);
		dp[u][0] += add;
		dp[u][1] += add - MIN;	
	}
}

// Contracted-tree counterpart of prepc. A node's size counts its remaining tree edges
// plus its folded degree-1 neighbours (value[u]). sz[u] becomes the maximum of that
// count over u's subtree, and DFS compares it with the cap to skip subtrees.
void PREPC(int u, int p) {
	int best = static_cast<int>(g[u].size() + value[u].size());
	for (const auto &edge : g[u]) {
		const int child = edge.first;
		if (child != p) {
			PREPC(child, u);
			best = max(best, sz[child]);
		}
	}
	sz[u] = best;
}

// Phase-2 DP on the contracted tree for a single cap curK (> base). The dp meaning
// matches dfs. A node's candidate gains come from two ascending-sorted sources:
//   luu      - positive gains of keeping the edge to each non-leaf child;
//   value[u] - weights of edges to folded degree-1 neighbours (prefix sums in pf[u]).
// The curK largest candidates are added to dp[u][0]. If all curK slots are used,
// dp[u][1] gives up the smallest taken candidate to make room for the parent edge.
void DFS(int u, int p, int &curK, vector <int> &w) {
	dp[u][0] = dp[u][1] = 0;
	vector <ll> luu;

	for (const pair <int, int> &P : g[u]) {
		int v = P.first;
		if (v == p) continue;
		if (curK <= sz[v]) DFS(v, u, curK, w);
		ll giatri = dp[v][1] + w[P.second] - dp[v][0];  // gain of keeping edge (u, v)
		if (giatri > 0) luu.push_back(giatri);
		dp[u][1] += dp[v][0];
		dp[u][0] += dp[v][0];
	}

	sort(luu.begin(), luu.end());
	// recur: free slots; value[u][0, ind) and luu[0, index) are still untaken.
	int recur = curK, ind = value[u].size(), index = luu.size();
	ll add = 0, MIN = LLONG_MAX;  // sum and minimum of the taken candidates

	// Greedy merge of both sources from the largest candidate downward.
	while (recur != 0 && index != 0) {
		ll csp = luu[--index];

		if (ind != 0) {
			// Take the untaken leaf weights >= csp, at most `recur` of the largest.
			int L = lower_bound(value[u].begin(), value[u].begin() + ind, csp) - value[u].begin();
			L = max(L, ind - recur);
			// Only when something is actually taken: with L == ind == value[u].size(),
			// value[u][L] would read past the end of the vector.
			if (L < ind) {
				recur -= ind - L;
				add += pf[u][ind - 1] - (L != 0 ? pf[u][L - 1] : 0);
				MIN = min(MIN, value[u][L]);
				ind = L;
			}
		}

		if (recur != 0) add += csp, recur--, MIN = min(MIN, csp);
	}

	// Child gains exhausted: fill the free slots with the largest leftover leaf weights.
	if (recur != 0 && ind != 0) {
		int vitri = max(ind - recur, 0);
		add += pf[u][ind - 1] - (vitri != 0 ? pf[u][vitri - 1] : 0);
		MIN = min(MIN, value[u][vitri]);
		recur -= ind - vitri;
	}

	dp[u][0] += add;
	dp[u][1] += add;
	if (recur == 0) dp[u][1] -= MIN;  // every slot used: the parent edge displaces the smallest
}

// Entry point: ans[k] is the minimum total weight of edges to remove so that every
// vertex has degree <= k. Edge i joins u[i] and v[i] with weight w[i].
// It maximizes the weight of KEPT edges and subtracts that from the total `sum`:
//   - k <= base: tree DP `dfs` on the original tree rooted at 0, once per k;
//   - k >  base: only vertices of degree > base ("heavy") can exceed the cap.
//     Light-light edges are always kept (their total is TONG), so light components are
//     contracted with the DSU. Degree-1 nodes of the contracted tree are folded into
//     their neighbour's sorted value/pf lists, and `DFS` runs on the remaining tree.
// For k >= maxdeg nothing needs removing, so ans[k] stays 0.
// The global arrays and accumulators are not reset here, so a second call in the same
// process would see stale state.
vector <ll> minimum_closure_costs(int n, vector <int> u, vector <int> v, vector <int> w) {
	vector <ll> ans(n);

	// Build the tree and record the total weight and the maximum degree.
	for (int i = 0; i < n - 1; ++i) {
		g[u[i]].emplace_back(v[i], i);
		g[v[i]].emplace_back(u[i], i);
		sum += w[i];
		maxdeg = max({maxdeg, (int)g[u[i]].size(), (int)g[v[i]].size()});
	}

	// Phase 1: small caps.
	prepc(0, 0);
	for (int k = 0; k < min(maxdeg, base + 1); ++k) {
		dfs(0, 0, k, w);
		ans[k] = dp[0][0];
	}

	// Phase 2 is only needed for caps with base < k < maxdeg.
	if (maxdeg > base + 1) {
		// Contract every light-light edge; such an edge never violates a cap k > base.
		for (int i = 0; i < n; ++i) lab[i] = -1;
		for (int i = 0; i < n - 1; ++i) {
			if ((int)g[u[i]].size() > base || (int)g[v[i]].size() > base) continue;
			TONG += w[i];
			int r = find(u[i]), s = find(v[i]);
			if (r != s) unite(r, s);
		}

		// Rebuild g as the contracted tree over DSU representatives.
		for (int i = 0; i < n; ++i) g[i].clear();
		for (int i = 0; i < n - 1; ++i) {
			int x = find(u[i]);
			int y = find(v[i]);
			if (x != y) {
				g[x].emplace_back(y, i);
				g[y].emplace_back(x, i);
			}
		}

		// Fold degree-1 nodes into their neighbour as sorted weights plus prefix sums.
		for (int i = 0; i < n; ++i) if ((int)g[i].size() == 1) leaf[i] = 1;
		for (int i = 0; i < n; ++i) {
			vector <pair <int, int>> keep;
			for (const pair <int, int> &p : g[i]) {
				if (leaf[p.first]) value[i].push_back(w[p.second]);
				else keep.push_back(p);
			}
			sort(value[i].begin(), value[i].end());
			ll prefix = 0;
			for (ll val : value[i]) {  // ll: value holds ll, no narrowing
				prefix += val;
				pf[i].push_back(prefix);
			}
			swap(keep, g[i]);
		}

		// Root at any non-leaf representative. One exists here: a vertex of degree
		// > base + 1 is never merged, and in a tree its edges stay distinct after
		// contraction, so it is not a leaf.
		for (int i = 0; i < n; ++i) {
			int r = find(i);
			if (!leaf[r]) {
				root = r;
				break;
			}
		}

		PREPC(root, root);
		for (int k = base + 1; k < maxdeg; ++k) {
			DFS(root, root, k, w);
			ans[k] = dp[root][0] + TONG;
		}
	}

	// Convert kept weight into removed weight.
	for (int i = 0; i < maxdeg; ++i) ans[i] = sum - ans[i];
	return ans;
}
/*
 * Per-subtask result panes (their contents had not loaded when the page was captured):
 * # | Verdict | Execution time | Memory | Grader output — Fetching results...
 * # | Verdict | Execution time | Memory | Grader output — Fetching results...
 * # | Verdict | Execution time | Memory | Grader output — Fetching results...
 * # | Verdict | Execution time | Memory | Grader output — Fetching results...
 * # | Verdict | Execution time | Memory | Grader output — Fetching results...
 * # | Verdict | Execution time | Memory | Grader output — Fetching results...
 * # | Verdict | Execution time | Memory | Grader output — Fetching results...
 */