제출 #369870

| # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
| 369870 | | penguinhacker | Sjekira (COCI20_sjekira) | C++14 | 110 / 110 | 369 ms | 24152 KiB |
#include <bits/stdc++.h>
using namespace std;

// Shorthand aliases: `ll` expands to long long, `ar` expands to array.
#define ll long long
#define ar array

// Prints "[<expression text>] = [<value>]" followed by a newline to cerr.
#define debug(x) cerr << "[" << #x << "] = [" << x << "]\n"

// Streams a pair as "[first, second]".
// The pair is taken by const reference so printing never copies it.
template<class A, class B> ostream& operator<< (ostream& out, const pair<A, B>& p) {
	return out << '[' << p.first << ", " << p.second << ']';
}

// Streams a vector as "[e0, e1, ...]"; an empty vector prints as "[]".
// The vector is taken by const reference (no copy) and walked with a range-for,
// which also removes the signed/unsigned index comparison.
template<class T> ostream& operator<< (ostream& out, const vector<T>& v) {
	out << '[';
	bool first = true;
	for (const auto& elem : v) {
		// Separator goes before every element except the first.
		if (!first) {
			out << ", ";
		}
		first = false;
		out << elem;
	}
	return out << ']';
}

// Streams a set as "[e0, e1, ...]" in the set's iteration order; an empty set
// prints as "[]".
// The set is taken by const reference and written directly, so neither the set
// nor a temporary vector of its elements is copied just for printing.
template<class T> ostream& operator<< (ostream& out, const set<T>& s) {
	out << '[';
	for (auto it = s.begin(); it != s.end(); ++it) {
		// Separator goes before every element except the first.
		if (it != s.begin()) {
			out << ", ";
		}
		out << *it;
	}
	return out << ']';
}

// Capacity of the per-node arrays below.
const int mxN = 100000;
// n:      node count read in main.
// p:      union-find parent links; p[i] == i marks a root.
// mx:     per-node input value; after merges, mx[root] is the maximum over the merged component.
// offset: per-root additive shift for keys stored in adj[root] (real cost = key + offset[root]).
int n, p[mxN], mx[mxN], offset[mxN];
// Running sum of the costs of all accepted merges; printed at the end of main.
ll ans = 0;
// adj[r]: (relative key, neighbour node) entries for root r, ordered by key.
set<pair<int, int>> adj[mxN];
// Candidate merges {cost, node a, node b}; main consumes them in ascending order.
set<ar<int, 3>> small;

// Returns the root of the component containing i and re-points p[i]
// (and, through the recursion, every node visited on the way) at that root.
int find(int i) {
	if (p[i] == i) {
		return i;
	}
	const int root = find(p[i]);
	p[i] = root;
	return root;
}

// Unites the components rooted at a and b (main passes find() results, so both
// are roots on entry) and re-publishes the surviving root's cheapest outgoing
// candidate into `small`.
//
// Keys stored in adj[r] are relative: key + offset[r] is the real cost
// mx[r] + mx[neighbour root] as of when the entry was written.
// NOTE: the locals named `p` below shadow the global parent array `p`.
void merge(int a, int b) {
	// Keep the root with the larger adjacency set, so the smaller set is the one copied over.
	if (adj[a].size() < adj[b].size()) swap(a, b);
	p[b] = a;
	if (mx[a] >= mx[b]) {
		// a's maximum survives: convert each of b's entries to a real cost
		// (+ offset[b]), swap mx[b] for mx[a], then shift into a's frame (- offset[a]).
		for (const pair<int, int>& p : adj[b]) {
			// Skip edges that now lead back into the merged component.
			if (find(a) != find(p.second))
			adj[a].emplace(p.first + offset[b] - offset[a] + mx[a] - mx[b], p.second);
		}
	}
	else {
		// b's maximum survives: bumping offset[a] raises the real cost of every
		// entry already in adj[a] by mx[b] - mx[a] without touching the entries.
		offset[a] += mx[b] - mx[a];
		for (const pair<int, int>& p : adj[b]) {
			// b's real costs already use mx[b]; only re-base them into a's frame.
			if (find(a) != find(p.second))
			adj[a].emplace(p.first + offset[b] - offset[a], p.second);
		}
	}
	// Swap with an empty temporary so adj[b]'s nodes are released immediately.
	set<pair<int, int>>().swap(adj[b]);
	mx[a] = max(mx[a], mx[b]);
	// Drop or refresh front entries until the smallest one is valid: it must lead
	// outside a's component and its key must match the current maxima.
	while(!adj[a].empty()) {
		pair<int, int> p = *adj[a].begin();
		if (find(p.second) != a && mx[a] + mx[find(p.second)] - offset[a] == p.first) break;
		adj[a].erase(adj[a].begin());
		// A stale entry that still leads elsewhere is re-inserted with a
		// recomputed key and the neighbour component's root.
		if (find(p.second) != a) adj[a].emplace(mx[a] + mx[find(p.second)] - offset[a], find(p.second));
	}
	//debug(offset[a]);
	//debug(adj[a]);
	if (!adj[a].empty()) {
		// Publish the cheapest outgoing entry with its real cost.
		pair<int, int> p = *adj[a].begin();
		assert(p.first + offset[a] == mx[a] + mx[find(p.second)]);
		small.insert({p.first + offset[a], a, p.second});
	}
}

int main() {
	ios::sync_with_stdio(0);
	cin.tie(0);
	cin >> n;
	for (int i = 0; i < n; ++i) {
		cin >> mx[i];
		p[i] = i;
	}
	for (int i = 1; i < n; ++i) {
		int a, b; cin >> a >> b, --a, --b;
		adj[a].emplace(mx[a] + mx[b], b);
		adj[b].emplace(mx[a] + mx[b], a);
		small.insert({mx[a] + mx[b], a, b});
	}
	while(!small.empty()) {
		ar<int, 3> t = *small.begin(); small.erase(small.begin());
		int a = find(t[1]), b = find(t[2]);
		if (a == b || t[0] != mx[a] + mx[b]) continue; // fraud...
		//cerr << a << " " << b << "\n";
		ans += t[0];
		assert(!adj[a].empty() && !adj[b].empty());
		merge(a, b);
	}
	cout << ans;
	return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...