Submission #882968

# | Time | Username | Problem | Language | Result | Execution time | Memory
882968 | (not shown) | tsumondai | Travelling Trader (CCO23_day2problem2) | C++14 | 25 / 25 | 304 ms | 132392 KiB
#include <bits/stdc++.h>
using namespace std;

// Competitive-programming shorthand.
// NOTE: `int` is redefined as `long long` for the whole file, so every "int"
// below (loop counters, vertex ids, profits, dp values) is 64-bit; this is
// why the entry point is declared `signed main()`.
#define int long long
#define fi first
#define se second
#define pb push_back
#define mp make_pair
// Inclusive loops: foru walks l..r upwards, ford walks r..l downwards.
#define foru(i, l, r) for(int i = l; i <= r; i++)
#define ford(i, r, l) for(int i = r; i >= l; i--)
// CPU seconds consumed so far, via clock().
// NOTE(review): identifiers starting with "__" are reserved for the
// implementation; a name such as ELAPSED_SECONDS would be safer.
#define __TIME  (1.0 * clock() / CLOCKS_PER_SEC)

// Pair shorthands (all components are 64-bit because of the `int` macro).
typedef pair<int, int> ii;
typedef pair<int, ii> iii;
typedef pair<ii, ii> iiii;

// Capacity of the per-vertex global arrays; vertices are read as 1..n in main().
const int N = 1e6 + 5;

// Generic constants; not referenced by the visible code (nor are mp, ford, iiii).
const int oo = 1e18, mod = 1e9 + 7;

// n: vertex count, k: selects the solver in main(), m: unused here.
// a[i]: profit of vertex i; edges[u]: undirected adjacency list of the tree.
int n, m, k, a[N];
vector<int> edges[N];


namespace subtask1 {
ii dist[N];
void dfs(int u, int par) {
	dist[u] = {0, u};
	for (int v : edges[u]) {
		if (v == par) continue;
		dfs(v, u);
		dist[u] = max(dist[u], {dist[v].fi, v});
	}
	dist[u].fi += a[u];
}
void process() {
	dfs(1, 0);
	cout << dist[1].fi << '\n';
	vector<int> res;
	int u = 1; res.push_back(u);
	while (dist[u].se != u) {
		u = dist[u].se;
		res.push_back(u);
	}
	cout << (int)res.size() << '\n';
	for (int v : res) cout << v << (v == res.back() ? '\n' : ' ');
}
}
namespace subtask2 {
// Solver used by main() when k == 2: a tree DP with four states per vertex,
// followed by a reconstruction pass (f) that writes the visiting order to res.
//
// dp[u][t1][t2] is the best profit collected strictly below u (a[u] is NOT
// included; process() adds a[1] for the root). The original inline comments
// describe the states as:
//   (0,1) "cur = u and jump back to par"
//   (1,1) "cur = child of u and jump back to u"
//   (0,0) "cur = u and not jump back"
//   (1,0) "cur = child of u and not jump back"
// In f(), t1 == 0 states push u before descending, and t1 == 1 states push u
// only after a child's route.
// NOTE(review): the exact movement invariants behind each state are inferred
// from those comments and from f(); confirm against the problem statement.
int dp[N][2][2];
// trace[u][t1][t2] stores the children chosen for dp[u][t1][t2] (0 = none):
//   [0][1], [1][1]: {child, {0, 0}}
//   [0][0]: {-1, {v, 0}}  -> only child v is entered, in state (1,0)
//           {c1, {c2, 0}} -> c1 in state (1,1), c2 in state (0,0)
//   [1][0]: {-1, {b, c}}  -> b in state (0,1), c in state (1,0)
//           {a1, {b, c}}  -> a1 in (0,1), b in (1,1), c in (0,0)
iii trace[N][2][2];
// Visiting order produced by f().
vector<int> res;

// Bottom-up DP over the subtree of u (children = neighbours other than par).
// Fills dp[u][*][*] and trace[u][*][*]. Recursion depth equals the height of
// the tree below u.
void dfs(int u, int par) {
	// total = sum of a[v] over the children v of u. Every transition except
	// option A of (0,0) counts all children of u as collected.
	int total = 0;
	for (int v : edges[u]) {
		if (v == par) continue;
		dfs(v, u); total += a[v];
	}
	{	//cur = u and jump back to par
		// Enter at most one child v, in state (1,1), and collect every child.
		ii Max = {0, 0};
		for (int v : edges[u]) {
			if (v == par) continue;
			Max = max(Max, {dp[v][1][1] + total, v});
		}
		dp[u][0][1] = Max.fi;
		trace[u][0][1] = {Max.se, {0, 0}};
	}
	{	//cur = child of u and jump back to u
		// Enter at most one child v, in state (0,1), and collect every child.
		ii Max = {0, 0};
		for (int v : edges[u]) {
			if (v == par) continue;
			Max = max(Max, {dp[v][0][1] + total, v});
		}
		dp[u][1][1] = Max.fi;
		trace[u][1][1] = {Max.se, {0, 0}};
	}
	{	//cur = u and not jump back
		// Option A (tagged -1): enter a single child v in state (1,0); only
		//   a[v] plus that child's route is collected.
		// Option B: one child in state (1,1) (list p1) plus a different child
		//   in state (0,0) (list p2), collecting every child of u.
		// The {0, 0} sentinel in each list stands for "no child used".
		iii Max = {0, {0, 0}};
		vector<ii> p1 = {{0, 0}}, p2 = {{0, 0}};
		for (int v : edges[u]) {
			if (v == par) continue;
			Max = max(Max, {dp[v][1][0] + a[v], { -1, v}});
			p1.push_back({dp[v][1][1], v});
			p2.push_back({dp[v][0][0], v});
		}
		sort(p1.begin(), p1.end(), greater<ii>());
		sort(p2.begin(), p2.end(), greater<ii>());
		// Two lists that must use distinct children: the best valid pair
		// always lies within the top 2 entries of each list.
		for (int i = 0; i < min(2LL, (int)p1.size()); i++)
			for (int j = 0; j < min(2LL, (int)p2.size()); j++) {
				// Same real child in both roles is invalid; sentinels may repeat.
				if (p2[j].se != 0 && p2[j].se == p1[i].se) continue;
				Max = max(Max, {p1[i].fi + p2[j].fi + total, {p1[i].se, p2[j].se}});
			}
		dp[u][0][0] = Max.fi;
		trace[u][0][0] = {Max.se.fi, {Max.se.se, 0}};
	}
	{	//cur = child of u and not jump back
		// Option A (tagged -1): one child in state (0,1) plus a different
		//   child in state (1,0).
		// Option B: three pairwise-distinct children in states (0,1), (1,1)
		//   and (0,0) respectively.
		// Both options collect every child of u (total).
		pair<int, iii> Max = {0, {0, {0, 0}}};
		{
			iii Max2 = {0, {0, 0}};
			vector<ii> p1 = {{0, 0}}, p2 = {{0, 0}};
			for (int v : edges[u]) {
				if (v == par) continue;
				p1.push_back({dp[v][0][1], v});
				p2.push_back({dp[v][1][0], v});
			}
			sort(p1.begin(), p1.end(), greater<ii>());
			sort(p2.begin(), p2.end(), greater<ii>());
			for (int i = 0; i < min(2LL, (int)p1.size()); i++)
				for (int j = 0; j < min(2LL, (int)p2.size()); j++) {
					if (p2[j].se != 0 && p2[j].se == p1[i].se) continue;
					Max2 = max(Max2, {p1[i].fi + p2[j].fi + total, {p1[i].se, p2[j].se}});
				}
			Max = max(Max, {Max2.fi, { -1, Max2.se}});
		}

		// Option B: with three lists that need distinct children, the top 3
		// entries of each list are enough.
		vector<ii> p1 = {{0, 0}}, p2 = {{0, 0}}, p3 = {{0, 0}};
		for (int v : edges[u]) {
			if (v == par) continue;
			p1.push_back({dp[v][0][1], v});
			p2.push_back({dp[v][1][1], v});
			p3.push_back({dp[v][0][0], v});
		}
		sort(p1.begin(), p1.end(), greater<ii>());
		sort(p2.begin(), p2.end(), greater<ii>());
		sort(p3.begin(), p3.end(), greater<ii>());
		for (int i = 0; i < min(3LL, (int)p1.size()); i++)
			for (int j = 0; j < min(3LL, (int)p2.size()); j++)
				for (int k = 0; k < min(3LL, (int)p3.size()); k++) {
					if (p2[j].se != 0 && p2[j].se == p1[i].se) continue;
					if (p3[k].se != 0 && (p3[k].se == p1[i].se || p3[k].se == p2[j].se)) continue;
					Max = max(Max, {p1[i].fi + p2[j].fi + p3[k].fi + total, {p1[i].se, {p2[j].se, p3[k].se}}});
				}
		dp[u][1][0] = Max.fi;
		trace[u][1][0] = Max.se;
	}
}

// Rebuilds the visiting order of state (t1, t2) at u from trace[][][] and
// appends it to res. Each branch mirrors one DP transition in dfs().
void f(int u, int par, int t1, int t2) {
	if (!t1 && t2) {
		// State (0,1): u, the chosen child's (1,1) route, then the other children.
		res.push_back(u);
		if (trace[u][0][1].fi) f(trace[u][0][1].fi, u, 1, 1);
		for (int v : edges[u]) {
			if (v == par) continue;
			if (v == trace[u][0][1].fi) continue;
			res.push_back(v);
		}
	}
	else if (t1 && t2) {
		// State (1,1): the other children, the chosen child's (0,1) route, u last.
		for (int v : edges[u]) {
			if (v == par) continue;
			if (v == trace[u][1][1].fi) continue;
			res.push_back(v);
		}
		if (trace[u][1][1].fi) f(trace[u][1][1].fi, u, 0, 1);
		res.push_back(u);
	}
	else if (!t1 && !t2) {
		// State (0,0): u first.
		res.push_back(u);
		if (trace[u][0][0].fi == -1) {
			// Option A: only the single (1,0) child is entered.
			if (trace[u][0][0].se.fi) f(trace[u][0][0].se.fi, u, 1, 0);
			return;
		}
		// Option B: (1,1) child, the remaining children, then the (0,0) child.
		if (trace[u][0][0].fi) f(trace[u][0][0].fi, u, 1, 1);
		for (int v : edges[u]) {
			if (v == par) continue;
			if (v == trace[u][0][0].se.fi || v == trace[u][0][0].fi) continue;
			res.push_back(v);
		}
		if (trace[u][0][0].se.fi) f(trace[u][0][0].se.fi, u, 0, 0);
	}
	else {
		// State (1,0).
		if (trace[u][1][0].fi == -1) {
			// Option A: children other than the two chosen ones, the (0,1)
			// child's route, u, then the (1,0) child's route.
			for (int v : edges[u]) {
				if (v == par) continue;
				if (v == trace[u][1][0].se.fi || v == trace[u][1][0].se.se) continue;
				res.push_back(v);
			}
			if (trace[u][1][0].se.fi) f(trace[u][1][0].se.fi, u, 0, 1);
			res.push_back(u);
			if (trace[u][1][0].se.se) f(trace[u][1][0].se.se, u, 1, 0);
			return;
		}
		// Option B: (0,1) child, u, (1,1) child, the remaining children, then
		// the (0,0) child.
		if (trace[u][1][0].fi) f(trace[u][1][0].fi, u, 0, 1);
		res.push_back(u);
		if (trace[u][1][0].se.fi) f(trace[u][1][0].se.fi, u, 1, 1);
		for (int v : edges[u]) {
			if (v == par) continue;
			if (v == trace[u][1][0].fi || v == trace[u][1][0].se.fi || v == trace[u][1][0].se.se) continue;
			res.push_back(v);
		}
		if (trace[u][1][0].se.se) f(trace[u][1][0].se.se, u, 0, 0);
	}
}

// Runs the DP from root 1 and reconstructs the route starting in state (0,0).
// Prints the profit (dp excludes a[1], so it is added here), the route
// length, and the route.
void process() {
	dfs(1, 0);
	f(1, 0, 0, 0);
	cout << dp[1][0][0] + a[1] << '\n';
	cout << (int)res.size() << '\n';
	for (int v : res) cout << v << (v == res.back() ? '\n' : ' ');
}
}

namespace subtask3 {

// Solver used by main() when k == 3: every vertex is collected.
// Order: u first; then, for each child c of u, the recursive orders of c's
// own children (grandchildren of u), followed by c itself.
vector<int> res;
void dfs(int u, int par) {
	res.push_back(u);
	for (int child : edges[u]) {
		if (child == par) continue;
		for (int grand : edges[child])
			if (grand != u) dfs(grand, child);
		res.push_back(child);
	}
}

// Prints the sum of all profits, the vertex count n, and the route.
void process() {
	const int total = std::accumulate(a + 1, a + n + 1, 0LL);
	dfs(1, 0);
	cout << total << '\n';
	cout << n << '\n';
	for (size_t i = 0; i < res.size(); ++i)
		cout << res[i] << (i + 1 == res.size() ? '\n' : ' ');
}

}

// Entry point: reads n and k, the n - 1 tree edges and the profits a[1..n],
// then dispatches to the solver for k.
// k >= 3 also uses subtask3. It already collects every vertex, and assuming
// k bounds the distance of each move (problem statement), a route valid for
// k = 3 stays valid for any larger k. Previously any other k printed nothing.
signed main() {
	cin.tie(0)->sync_with_stdio(false);
	//freopen(".inp", "r", stdin);
	//freopen(".out", "w", stdout);
	cin >> n >> k;
	foru(i, 1, n - 1) {
		int u, v; cin >> u >> v;
		edges[u].pb(v);
		edges[v].pb(u);
	}
	foru(i, 1, n) cin >> a[i];
	if (k == 1) subtask1::process();
	else if (k == 2) subtask2::process();
	else if (k >= 3) subtask3::process();
	cerr << "Time elapsed: " << __TIME << " s.\n";
	return 0;
}

// dont stop
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...