Submission #794337

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|---|---|---|---|---|---|---|
| 794337 |  | qwerasdfzxcl | Travelling Trader (CCO23_day2problem2) | C++17 | 25 / 25 | 441 ms | 180212 KiB |
#include <bits/stdc++.h>

using namespace std;
using ll = long long;

// Input shared by every solver below, filled in by main():
//   n      - number of vertices (vertices are numbered 1..n),
//   adj[u] - neighbours of vertex u in the tree,
//   a[u]   - value attached to vertex u.
// Index 0 is unused, so the arrays hold vertices 1..200199.
int n;
std::vector<int> adj[200200];
ll a[200200];

namespace foo{

ll dp[200200];
int nxt[200200];

void dfs(int s, int pa = 0){
	dp[s] = a[s], nxt[s] = 0;
	for (auto &v:adj[s]) if (v!=pa){
		dfs(v, s);
		if (dp[s] < dp[v] + a[s]){
			dp[s] = dp[v] + a[s];
			nxt[s] = v;
		}
	}
}

void solve1(){
	dfs(1);
	printf("%lld\n", dp[1]);

	vector<int> V;
	for (int i=1;i;i=nxt[i]){
		V.push_back(i);
	}

	printf("%d\n", (int)V.size());
	for (auto &x:V) printf("%d ", x);
	printf("\n");
}

} // namespace foo


namespace bar{

// Sentinel for unreachable DP states; 4e18 still fits in a signed 64-bit ll.
constexpr ll INF = 4e18;

// dp[s][x][y]: best total for a route that stays inside the subtree of s.
// NOTE(review): judging from track() below, x is the depth (relative to s) at
// which the route starts (0 = s itself, 1 = a child of s) and y is the depth at
// which it ends (0 = s, 1 = a child), with y == 2 meaning the end point is
// unconstrained. Confirm against the problem statement.
// v1[s], v2[s][*], v3[s][*]: candidate ids chosen for the [0][1], [0][2] and
//   [1][2] states. Negative ids are "no child" placeholders; ids greater than
//   n encode child (id - n) visited on its own (see C4 in dfs).
// typ2[s] / typ3[s]: which transition produced dp[s][0][2] / dp[s][1][2].
// nxt[u]: successor of u in the output route, written by track().
ll dp[200200][2][3];
int v1[200200], v2[200200][2], v3[200200][3], typ2[200200], typ3[200200], nxt[200200];
// G[s]: children of s in the tree rooted at 1 (filled by dfs). track() erases
// the children it places specially before appending the rest in order.
vector<int> G[200200];

// Post-order DP over the subtree rooted at s (pa is its parent, 0 for the
// root). Fills dp, v1/v2/v3, typ2/typ3 and G for every vertex of the subtree.
void dfs(int s, int pa = 0){
	// Candidate lists of (gain, id). Each is seeded with three (0, -i)
	// placeholders so the fixed-width top-k scans below always have entries;
	// a negative id means "no child used".
	vector<pair<ll, int>> C1, C2, C3, C6;
	for (int i=1;i<=3;i++){
		C1.emplace_back(0, -i);
		C2.emplace_back(0, -i);
		C3.emplace_back(0, -i);
		C6.emplace_back(0, -i);
	}

	ll sum = 0; // sum of a[v] over the children v of s
	int deg = 0; // number of children of s

	for (auto &v:adj[s]) if (v!=pa){
		G[s].push_back(v);
		dfs(v, s);
		sum += a[v];
		deg++;

		// C1/C2/C6 gains subtract a[v] because a[v] is already in sum: they
		// measure the extra value of expanding v beyond visiting v alone.
		C1.emplace_back(dp[v][1][0] - a[v], v);
		C2.emplace_back(dp[v][0][2] - a[v], v);
		// C3 is combined without sum, so it keeps the full dp value.
		C3.emplace_back(dp[v][1][2], v);
		C6.emplace_back(dp[v][0][1] - a[v], v);
	}

	// Best gain first (ties broken by larger id).
	sort(C1.begin(), C1.end(), greater<pair<ll, int>>());
	sort(C2.begin(), C2.end(), greater<pair<ll, int>>());
	sort(C3.begin(), C3.end(), greater<pair<ll, int>>());
	sort(C6.begin(), C6.end(), greater<pair<ll, int>>());

	// calc 0 1
	// s, every child, and optionally one child v1[s] expanded via its [1][0]
	// route (a placeholder id means no child is expanded).
	dp[s][0][1] = a[s] + sum;
	dp[s][0][1] += C1[0].first;
	v1[s] = C1[0].second;

	// A leaf has no child to end on.
	if (deg==0) dp[s][0][1] = -INF;

	// calc 0 2
	// typ2 == 2: s and all children, one child expanded via C1 ([1][0]) and a
	// different child via C2 ([0][2]). With the distinct-id constraint, the
	// top two entries of each list suffice.
	dp[s][0][2] = -INF;
	for (int i=0;i<2;i++){
		for (int j=0;j<2;j++) if (C1[i].second != C2[j].second){
			ll val = a[s] + sum + C1[i].first + C2[j].first;

			if (dp[s][0][2] < val){
				dp[s][0][2] = val;
				typ2[s] = 2;

				v2[s][0] = C1[i].second;
				v2[s][1] = C2[j].second;
			}
		}
	}

	// typ2 == 1: s, then straight into the best child's [1][2] route; the
	// other children are not collected, so sum is not added.
	if (dp[s][0][2] < a[s] + C3[0].first){
		dp[s][0][2] = a[s] + C3[0].first;
		typ2[s] = 1;

		v2[s][0] = C3[0].second;
	}

	// calc 1 2
	dp[s][1][2] = -INF;

	// C4 = C1 plus a zero-gain entry (0, n+v) for each child v, meaning "the
	// route starts at child v itself" (track() rebuilds it as state (0,0)).
	vector<pair<ll, int>> C4 = C1;
	for (auto &v:adj[s]) if (v!=pa) C4.emplace_back(0, n+v);
	sort(C4.begin(), C4.end(), greater<pair<ll, int>>());

	// typ3 == 3: start at a child (C4), expand a second child via C1 and a
	// third via C2, collecting all siblings. A real child can appear in C4
	// twice (as v and n+v), hence the wider scan of C4.
	// NOTE(review): the widths 6/3/3 are assumed sufficient after the
	// distinctness filters; not re-derived here.
	for (int i=0;i<min(6, (int)C4.size());i++){
		for (int j=0;j<3;j++){
			for (int k=0;k<3;k++){
				int rc4 = C4[i].second;
				if (rc4 > n) rc4 -= n; // decode the "child alone" encoding
				if (rc4==C1[j].second || rc4==C2[k].second || C1[j].second==C2[k].second) continue;
				if (rc4 < 0) continue; // the start must be a real child

				ll val = a[s] + sum + C4[i].first + C1[j].first + C2[k].first;

				if (dp[s][1][2] < val){
					dp[s][1][2] = val;
					typ3[s] = 3;

					v3[s][0] = C4[i].second;
					v3[s][1] = C1[j].second;
					v3[s][2] = C2[k].second;
				}
			}
		}
	}

	// typ3 == 2: start at a child (C4), return to s, then jump into a
	// different child's [1][2] route; other siblings are not collected.
	for (int i=0;i<min(4, (int)C4.size());i++){
		for (int j=0;j<2;j++) if (C4[i].second != C3[j].second && C4[i].second > 0){
			int rc4 = C4[i].second;
			if (rc4 > n) rc4 -= n;
			if (rc4==C3[j].second) continue; // catches the n+v encoding of the same child
			if (rc4 < 0) continue;

			// No sum here, so a[rc4] is added back to the C4 gain. (The
			// guard is always true given the loop condition above.)
			ll val = a[s] + C4[i].first + C3[j].first;
			if (C4[i].second > 0) val += a[rc4];

			if (dp[s][1][2] < val){
				dp[s][1][2] = val;
				typ3[s] = 2;

				v3[s][0] = C4[i].second;
				v3[s][1] = C3[j].second;
			}
		}
	}

	// typ3 == 1: collect all siblings, expand one child via C6 ([0][1]) and
	// finish in a different child's [1][2] route. C5 holds the real-child C3
	// entries re-based by a[v], since sum is added here.
	vector<pair<ll, int>> C5;

	for (int i=0;i<(int)C3.size();i++) if (C3[i].second > 0){
		C5.emplace_back(C3[i].first-a[C3[i].second], C3[i].second);
	}

	sort(C5.begin(), C5.end(), greater<pair<ll, int>>());
	for (int i=0;i<2;i++){
		for (int j=0;j<min(2, (int)C5.size());j++) if (C6[i].second != C5[j].second){
			ll val = a[s] + sum + C5[j].first + C6[i].first;
			if (dp[s][1][2] < val){
				dp[s][1][2] = val;
				typ3[s] = 1;

				v3[s][0] = C6[i].second;
				v3[s][1] = C5[j].second;

			}
		}
	}

	if (deg==0) dp[s][1][2] = -INF;

	// calc 1 0
	// The [1][0] route is the [0][1] route built in reverse (see track()),
	// so it has the same value. Because this equality holds, track() rebuilds
	// C1 choices (dp[v][1][0]) with state (0,1).
	dp[s][1][0] = dp[s][0][1];

	// A [1][0] route is also a valid [1][2] route; typ3 == 0 records that.
	if (dp[s][1][2] < dp[s][1][0]) dp[s][1][2] = dp[s][1][0], typ3[s] = 0;
}

// Rebuilds the route that realizes dp[s][x][y] by writing nxt[] links, and
// returns {first vertex, last vertex} of that route. State (0,0) is the single
// vertex s. Erases entries from G[s]; a second call on the same s could
// erase(end()), so presumably each vertex is rebuilt at most once. TODO confirm.
pair<int, int> track(int s, int x, int y){
	if (x==0 && y==0) return {s, s};
	assert(s > 0);
	assert(dp[s][x][y] >= 0);

	// A [1][2] state whose best option was the [1][0] route is rebuilt as such.
	if (x==1 && y==2 && typ3[s]==0) y = 0;

	if (x==0 && y==1){
		// s, then (optionally) v1[s]'s [1][0] route ending at v1[s], then the
		// remaining children one by one.
		pair<int, int> ret = {s, s};

		if (v1[s] > 0){
			auto [l1, r1] = track(v1[s], 1, 0);
			nxt[s] = l1;
			ret.second = r1;

			G[s].erase(find(G[s].begin(), G[s].end(), v1[s]));
		}

		for (auto &v:G[s]){
			nxt[ret.second] = v;
			ret.second = v;
		}

		return ret;
	}

	if (x==0 && y==2){
		pair<int, int> ret = {s, s};

		if (typ2[s]==1){
			// s, then straight into v2[s][0]'s [1][2] route.
			if (v2[s][0] > 0){
				auto [l1, r1] = track(v2[s][0], 1, 2);
				nxt[s] = l1;

				ret.second = r1;
			}

			return ret;
			
		}

		// else (typ2[s]==2)
		// s, v2[s][0]'s [1][0] route, the other children, then v2[s][1]
		// followed by its [0][2] route.
		if (v2[s][0] > 0){
			auto [l1, r1] = track(v2[s][0], 1, 0);
			nxt[s] = l1;

			ret.second = r1;
			G[s].erase(find(G[s].begin(), G[s].end(), v2[s][0]));
		}

		if (v2[s][1] > 0){
			G[s].erase(find(G[s].begin(), G[s].end(), v2[s][1]));
		}

		for (auto &v:G[s]){
			nxt[ret.second] = v;
			ret.second = v;
		}

		if (v2[s][1] > 0){
			nxt[ret.second] = v2[s][1];
			ret.second = v2[s][1];

			// The [0][2] route starts at v2[s][1] itself, already linked above.
			auto [l2, r2] = track(v2[s][1], 0, 2);
			ret.second = r2;
		}

		return ret;
	}

	if (x==1 && y==0){
		// Mirror of (0,1), built backwards: remaining children (prepended),
		// then v1[s]'s (0,1) route, ending at s.
		pair<int, int> ret = {s, s};

		if (v1[s] > 0){
			auto [l1, r1] = track(v1[s], 0, 1);
			nxt[r1] = s;
			ret.first = l1;

			G[s].erase(find(G[s].begin(), G[s].end(), v1[s]));
		}

		for (auto &v:G[s]){
			nxt[v] = ret.first;
			ret.first = v;
		}

		assert(ret.first != s);
		return ret;
	}

	assert(x==1 && y==2);

	pair<int, int> ret = {s, s};

	if (typ3[s]==2){
		// Start piece: child rv3 alone (id > n, state (0,0)) or its (0,1)
		// route, linked into s; then s jumps into v3[s][1]'s [1][2] route.
		if (v3[s][0] > 0){
			int rv3 = v3[s][0], ri = 0, rj = 1;
			if (v3[s][0] > n) rv3 = v3[s][0]-n, ri = 0, rj = 0;

			auto [l1, r1] = track(rv3, ri, rj);
			ret.first = l1;
			nxt[r1] = s;

			G[s].erase(find(G[s].begin(), G[s].end(), rv3));
		}

		if (v3[s][1] > 0){
			auto [l2, r2] = track(v3[s][1], 1, 2);
			nxt[s] = l2;
			ret.second = r2;
		}

		assert(ret.first != s);
		return ret;
	}

	if (typ3[s]==1){
		// Remaining children (prepended), then v3[s][0]'s (0,1) route if one
		// was chosen, then s, then v3[s][1]'s [1][2] route.
		assert(v3[s][1] > 0);
		G[s].erase(find(G[s].begin(), G[s].end(), v3[s][1]));

		if (v3[s][0] > 0){
			G[s].erase(find(G[s].begin(), G[s].end(), v3[s][0]));

			auto [l2, r2] = track(v3[s][0], 0, 1);
			nxt[r2] = ret.first;
			ret.first = l2;
		}

		for (auto &v:G[s]){
			nxt[v] = ret.first;
			ret.first = v;
		}



		auto [l1, r1] = track(v3[s][1], 1, 2);
		nxt[s] = l1;
		ret.second = r1;

		return ret;
	}

	//typ3[s] == 3
	// Start piece from v3[s][0] (alone or its (0,1) route) into s, then
	// v3[s][1]'s [1][0] route, the remaining children, and finally v3[s][2]'s
	// [0][2] route.
	if (v3[s][0] > 0){
		int rv3 = v3[s][0], ri = 0, rj = 1;
		if (v3[s][0] > n) rv3 = v3[s][0]-n, ri = 0, rj = 0;

		auto [l1, r1] = track(rv3, ri, rj);
		ret.first = l1;
		nxt[r1] = s;

		G[s].erase(find(G[s].begin(), G[s].end(), rv3));
	}

	if (v3[s][1] > 0){
		auto [l2, r2] = track(v3[s][1], 1, 0);
		nxt[s] = l2;
		ret.second = r2;

		G[s].erase(find(G[s].begin(), G[s].end(), v3[s][1]));
	}

	if (v3[s][2] > 0){
		G[s].erase(find(G[s].begin(), G[s].end(), v3[s][2]));
	}

	for (auto &v:G[s]){
		nxt[ret.second] = v;
		ret.second = v;
	}

	if (v3[s][2] > 0){
		auto [l3, r3] = track(v3[s][2], 0, 2);
		nxt[ret.second] = l3;
		ret.second = r3;
	}

	assert(ret.first != s);
	return ret;
}

// Solver used by main() when k == 2. Resets nxt and G, runs the DP from
// vertex 1, and prints dp[1][0][2]. It then rebuilds the route and prints its
// length and vertices by following nxt from vertex 1 (0 terminates).
// Returns the printed total.
ll solve2(){
	for (int i=1;i<=n;i++) nxt[i] = 0, G[i].clear();
	dfs(1);

	ll ans = dp[1][0][2];
	printf("%lld\n", ans);

	track(1, 0, 2);

	vector<int> V;
	for (int i=1;i;i=nxt[i]){
		V.push_back(i);
	} 

	printf("%d\n", (int)V.size());
	for (auto &x:V) printf("%d ", x);
	printf("\n");

	return ans;
}

} // namespace bar

namespace baz{

void dfs1(int s, int pa);
void dfs2(int s, int pa);

vector<int> V;

void dfs1(int s, int pa){
	V.push_back(s);
	for (auto &v:adj[s]) if (v!=pa) dfs2(v, s);
}

void dfs2(int s, int pa){
	for (auto &v:adj[s]) if (v!=pa) dfs1(v, s);
	V.push_back(s);
}

void solve3(){
	dfs1(1, 0);

	ll sum = 0;
	for (int i=1;i<=n;i++) sum += a[i];

	printf("%lld\n", sum);
	printf("%d\n", n);
	for (auto &x:V) printf("%d ", x);
	printf("\n");
}

} // namespace baz


// Reads n and k, the n-1 tree edges, and the n vertex values, then dispatches
// to the solver for k (1 -> foo, 2 -> bar, anything else -> baz).
// Exits with status 1 and no output when the input is malformed (a scanf
// failure) or out of range. The original ignored scanf's result and could
// read uninitialized values or index past the fixed-size global arrays.
int main(){
	int k;
	if (scanf("%d %d", &n, &k) != 2) return 1;

	// Vertices are 1-indexed into the fixed-size global arrays, so index n
	// must be in bounds.
	const int capacity = (int)(sizeof(a) / sizeof(a[0]));
	if (n < 1 || n >= capacity) return 1;

	for (int i=1;i<=n-1;i++){
		int x, y;
		if (scanf("%d %d", &x, &y) != 2) return 1;
		if (x < 1 || x > n || y < 1 || y > n) return 1;
		adj[x].push_back(y);
		adj[y].push_back(x);
	}

	for (int i=1;i<=n;i++){
		if (scanf("%lld", a+i) != 1) return 1;
	}

	if (k==1) foo::solve1();
	else if (k==2) bar::solve2();
	else baz::solve3();
	return 0;
}

Compilation message (stderr)

Main.cpp: In function 'int main()':
Main.cpp:426:7: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  426 |  scanf("%d %d", &n, &k);
      |  ~~~~~^~~~~~~~~~~~~~~~~
Main.cpp:430:8: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  430 |   scanf("%d %d", &x, &y);
      |   ~~~~~^~~~~~~~~~~~~~~~~
Main.cpp:435:30: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  435 |  for (int i=1;i<=n;i++) scanf("%lld", a+i);
      |                         ~~~~~^~~~~~~~~~~~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...