제출 #851630

| # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
|---|---|---|---|---|---|---|---|
| 851630 |  | TranGiaHuy1508 | Travelling Trader (CCO23_day2problem2) | C++17 | 25 / 25 | 456 ms | 100364 KiB |
#include <bits/stdc++.h>
using namespace std;

void main_program();

// Program entry point: put iostreams into unsynchronized, untied mode for
// fast bulk I/O, then hand control to main_program().
signed main(){
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);
	main_program();
	return 0;
}

#define int long long

// Large sentinel (10^15); -inf marks DP entries that have not been computed yet.
const int inf = 1000000000000000LL;

// Vertex count and the parameter K that selects the solver in main_program.
int n, K;
// v[x]: value of vertex x (0-indexed).
vector<int> v;
// par[x]: parent of x in the tree rooted at vertex 0 (-1 for the root); filled by dfs.
vector<int> par;
// Undirected adjacency lists, 0-indexed.
vector<vector<int>> adj;
// sum_child[x]: sum of v over the children of x; filled by dfs.
vector<int> sum_child;
// max_path[x]: best downward path value starting at x; filled by dfs_K1.
vector<int> max_path;

// K = 2 solver: a tree DP over four "route shapes" at every vertex x (tree
// rooted at vertex 0), followed by a reconstruction pass that emits the route.
//
// dp_S[x] is the total value v[] of the vertices emitted by the best S-route
// inside x's subtree. The shapes, as built by the tracing_* functions:
//   00 : all children except one chosen child c, then the 11-route of c,
//        then x                                   -> ends at x
//   11 : x, then the 00-route of a chosen child c, then the remaining
//        children                                 -> starts at x
//   1x : starts at x, may end anywhere in the subtree
//   0x : starts at x or at a child of x, may end anywhere in the subtree
// NOTE(review): presumably every consecutive pair of emitted vertices is at
// tree distance <= 2 (the K = 2 constraint); this is not checked here.
namespace K2 {
	// Children chosen for the three recursive slots of a case-0 0x-route
	// (negative = slot unused). operator< is only a lexicographic tie-break so
	// that std::max can compare pair<int, Pair3>.
	struct Pair3{
		int path_11, path_00, path_1x;

		bool operator < (const Pair3 &P) const {
			if (path_11 != P.path_11) return (path_11 < P.path_11);
			if (path_00 != P.path_00) return (path_00 < P.path_00);
			return (path_1x < P.path_1x);
		}
	};

	// Memo tables, one entry per vertex; -inf means "not computed yet".
	vector<int> dp_00, dp_11, dp_1x, dp_0x;
	// Reconstruction data: child ids, negative = none.
	vector<int> trace_00, trace_11;        // child continued with the 11 / 00 route
	vector<pair<int, int>> trace_1x;       // {child continued with 1x, child continued with 00}
	vector<int> trace_skip_1x;             // if >= 0: 1x-route goes x -> 0x-route of this child
	vector<Pair3> trace_0x;                // slots of a case-0 0x-route
	vector<pair<int, int>> trace_0x_bruh;  // {child continued with 11, child continued with 0x} (case 1)
	vector<int> case_0x;                   // which 0x shape was chosen (0 or 1)

	int f_00(int x, int p);
	int f_11(int x, int p);
	int f_0x(int x, int p);
	int f_1x(int x, int p);

	// Best 00-route of x: v[x] + v of every child + the largest non-negative
	// gain f_11(k) - v[k] from extending one child k into its own subtree.
	// Memoized per vertex only. This is sound because every call passes x's
	// parent (rooted at 0) as p.
	int f_00(int x, int p = -1){
		if (dp_00[x] >= 0) return dp_00[x];
		int opt = 0, best = -1;
		for (auto k: adj[x]){
			if (k == p) continue;
			int diff = f_11(k, x) - v[k];
			if (diff > opt){
				opt = diff;
				best = k;
			}
		}
		int ans = sum_child[x] + v[x] + opt;
		trace_00[x] = best;
		dp_00[x] = ans;

		return dp_00[x];
	}

	// Best 11-route of x. Same formula as f_00, but the extended child
	// continues with a 00-route.
	int f_11(int x, int p = -1){
		if (dp_11[x] >= 0) return dp_11[x];
		int opt = 0, best = -1;
		for (auto k: adj[x]){
			if (k == p) continue;
			int diff = f_00(k, x) - v[k];
			if (diff > opt){
				opt = diff;
				best = k;
			}
		}
		int ans = sum_child[x] + v[x] + opt;
		trace_11[x] = best;
		dp_11[x] = ans;

		return dp_11[x];
	}

	// Best 0x-route of x, choosing between two shapes:
	//  case 0: 11-route of child a, x, 00-route of child b, the remaining
	//          children, 1x-route of child c (a, b, c distinct; each optional);
	//  case 1: the children other than a, b, then the 11-route of a, x,
	//          then the 0x-route of b.
	int f_0x(int x, int p = -1){
		if (dp_0x[x] >= 0) return dp_0x[x];
		// Gain (f_S(k) - v[k]) of continuing child k with route shape S.
		vector<pair<int, int>> delta_11, delta_00, delta_1x, delta_0x;

		for (auto k: adj[x]){
			if (k == p) continue;
			delta_11.emplace_back(f_11(k, x) - v[k], k);
			delta_00.emplace_back(f_00(k, x) - v[k], k);
			delta_1x.emplace_back(f_1x(k, x) - v[k], k);
			delta_0x.emplace_back(f_0x(k, x) - v[k], k);
		}

		// Leaf: the route is x alone.
		if (delta_00.empty()){
			dp_0x[x] = v[x];
			trace_0x[x] = Pair3{-1, -1, -1};
			return dp_0x[x];
		}

		sort(delta_11.begin(), delta_11.end(), greater<>());
		sort(delta_00.begin(), delta_00.end(), greater<>());
		sort(delta_1x.begin(), delta_1x.end(), greater<>());
		sort(delta_0x.begin(), delta_0x.end(), greater<>());

		// At most 3 slots need pairwise-distinct children, so any one slot can
		// be blocked by at most 2 others: the top 3 candidates per shape suffice.
		// size_t (not `long unsigned`) so std::min deduces one type everywhere.
		const size_t maxlen = 3;

		delta_11.resize(min(delta_11.size(), maxlen));
		delta_00.resize(min(delta_00.size(), maxlen));
		delta_1x.resize(min(delta_1x.size(), maxlen));
		delta_0x.resize(min(delta_0x.size(), maxlen));

		// "Leave this slot empty" options with gain 0. Each has a distinct
		// negative id so the distinctness checks below never reject them.
		delta_11.emplace_back(0, -2);
		delta_00.emplace_back(0, -3);
		delta_1x.emplace_back(0, -4);
		delta_0x.emplace_back(0, -5);

		// Case 0: best assignment of distinct children to the 11/00/1x slots.
		pair<int, Pair3> final_opt(0, Pair3{-1, -1, -1});
		for (auto p11: delta_11){
			for (auto p00: delta_00){
				if (p11.second == p00.second) continue;
				for (auto p1x: delta_1x){
					if (p11.second == p1x.second) continue;
					if (p00.second == p1x.second) continue;

					pair<int, Pair3> crr_opt = {
						p11.first + p00.first + p1x.first,
						Pair3{
							p11.second,
							p00.second,
							p1x.second
						}
					};

					final_opt = max(final_opt, crr_opt);
				}
			}
		}

		int ans1 = sum_child[x] + v[x] + final_opt.first;

		// Case 1: best assignment of distinct children to the 11/0x slots.
		pair<int, pair<int, int>> bruh_opt = {0, {-1, -1}};
		for (auto p11: delta_11){
			for (auto p0x: delta_0x){
				if (p11.second == p0x.second) continue;

				pair<int, pair<int, int>> crr_opt = {
					p11.first + p0x.first,
					{
						p11.second, p0x.second
					}
				};

				bruh_opt = max(bruh_opt, crr_opt);
			}
		}

		int ans2 = sum_child[x] + v[x] + bruh_opt.first;

		if (ans1 >= ans2){
			trace_0x[x] = final_opt.second;
			dp_0x[x] = ans1;
		}
		else{
			trace_0x_bruh[x] = bruh_opt.second;
			dp_0x[x] = ans2;
			case_0x[x] = 1;
		}

		return dp_0x[x];
	}

	// Best 1x-route of x, choosing between two shapes:
	//  - x, 00-route of child b, the remaining children, 1x-route of child a
	//    (a != b, each optional);
	//  - "skip": x, then directly the 0x-route of one child k, with x's other
	//    children left unvisited.
	int f_1x(int x, int p = -1){
		if (dp_1x[x] >= 0) return dp_1x[x];
		vector<pair<int, int>> delta_00, delta_1x;
		for (auto k: adj[x]){
			if (k == p) continue;
			delta_00.emplace_back(f_00(k, x) - v[k], k);
			delta_1x.emplace_back(f_1x(k, x) - v[k], k);
		}

		// Leaf: the route is x alone.
		if (delta_00.empty()){
			dp_1x[x] = v[x];
			trace_1x[x] = {-1, -1};
			return dp_1x[x];
		}

		sort(delta_00.begin(), delta_00.end(), greater<>());
		sort(delta_1x.begin(), delta_1x.end(), greater<>());

		// Two slots with distinct children: the top 2 candidates per shape suffice.
		const size_t maxlen = 2;

		delta_00.resize(min(delta_00.size(), maxlen));
		delta_1x.resize(min(delta_1x.size(), maxlen));

		// Gain-0 "empty slot" options with distinct negative ids.
		delta_00.emplace_back(0, -2);
		delta_1x.emplace_back(0, -3);

		pair<int, pair<int, int>> final_opt(0, {-1, -1});
		for (auto p00: delta_00){
			for (auto p1x: delta_1x){
				if (p00.second == p1x.second) continue;

				pair<int, pair<int, int>> crr_opt = {
					p00.first + p1x.first,
					{
						p1x.second, p00.second
					}
				};

				final_opt = max(final_opt, crr_opt);
			}
		}

		int ans = sum_child[x] + v[x] + final_opt.first;
		trace_1x[x] = final_opt.second;

		// "Skip" shape: only strictly better values replace the shape above.
		for (auto k: adj[x]){
			if (k == p) continue;
			int newval = v[x] + f_0x(k, x);
			if (newval > ans){
				trace_skip_1x[x] = k;
				ans = newval;
			}
		}

		dp_1x[x] = ans;

		return dp_1x[x];
	}

	// Route being reconstructed (0-indexed vertex ids, in visiting order).
	vector<int> trace;

	void tracing_00(int x, int p);
	void tracing_11(int x, int p);
	void tracing_0x(int x, int p);
	void tracing_1x(int x, int p);

	// Appends the 00-route of x recorded by f_00 (ends at x).
	void tracing_00(int x, int p = -1){
		for (auto k: adj[x]){
			if (k == p) continue;
			if (k == trace_00[x]) continue;
			trace.push_back(k);
		}
		if (trace_00[x] >= 0){
			tracing_11(trace_00[x], x);
		}
		trace.push_back(x);
	}

	// Appends the 11-route of x recorded by f_11 (starts at x).
	void tracing_11(int x, int p = -1){
		trace.push_back(x);
		if (trace_11[x] >= 0){
			tracing_00(trace_11[x], x);
		}
		for (auto k: adj[x]){
			if (k == p) continue;
			if (k == trace_11[x]) continue;
			trace.push_back(k);
		}
	}

	// Appends the 0x-route of x recorded by f_0x (shape chosen by case_0x).
	void tracing_0x(int x, int p = -1){
		if (case_0x[x] == 0){
			if (trace_0x[x].path_11 >= 0){
				tracing_11(trace_0x[x].path_11, x);
			}

			trace.push_back(x);

			if (trace_0x[x].path_00 >= 0){
				tracing_00(trace_0x[x].path_00, x);
			}

			for (auto k: adj[x]){
				if (k == p) continue;
				if (k == trace_0x[x].path_11) continue;
				if (k == trace_0x[x].path_00) continue;
				if (k == trace_0x[x].path_1x) continue;
				trace.push_back(k);
			}

			if (trace_0x[x].path_1x >= 0){
				tracing_1x(trace_0x[x].path_1x, x);
			}
		}
		else{
			for (auto k: adj[x]){
				if (k == p) continue;
				if (k == trace_0x_bruh[x].first) continue;
				if (k == trace_0x_bruh[x].second) continue;
				trace.push_back(k);
			}

			if (trace_0x_bruh[x].first >= 0){
				tracing_11(trace_0x_bruh[x].first, x);
			}
			trace.push_back(x);
			if (trace_0x_bruh[x].second >= 0){
				tracing_0x(trace_0x_bruh[x].second, x);
			}
		}
	}

	// Appends the 1x-route of x recorded by f_1x (starts at x).
	void tracing_1x(int x, int p = -1){
		trace.push_back(x);

		if (trace_skip_1x[x] >= 0){
			tracing_0x(trace_skip_1x[x], x);
			return;
		}

		if (trace_1x[x].second >= 0){
			tracing_00(trace_1x[x].second, x);
		}
		for (auto k: adj[x]){
			if (k == p) continue;
			if (k == trace_1x[x].first) continue;
			if (k == trace_1x[x].second) continue;
			trace.push_back(k);
		}
		if (trace_1x[x].first >= 0){
			tracing_1x(trace_1x[x].first, x);
		}
	}

	// Resets all tables, then prints the best 1x-route value from vertex 0,
	// the route length, and the route itself (1-indexed, one line).
	void solve(){
		dp_00.assign(n, -inf);
		dp_11.assign(n, -inf);
		dp_0x.assign(n, -inf);
		dp_1x.assign(n, -inf);

		trace_00.assign(n, -1);
		trace_11.assign(n, -1);
		trace_1x.assign(n, {-1, -1});
		trace_0x.assign(n, Pair3{-1, -1, -1});
		trace_skip_1x.assign(n, -1);
		trace_0x_bruh.assign(n, {-1, -1});
		case_0x.assign(n, 0);

		cout << f_1x(0) << "\n";
		tracing_1x(0);
		cout << trace.size() << "\n";
		for (size_t i = 0; i < trace.size(); i++){
			cout << trace[i] + 1 << " \n"[i + 1 == trace.size()];
		}
	}
}

// Roots the tree at the first call's vertex: stores each vertex's parent in
// par[] (-1 for the root) and adds v[] of every direct child of x into
// sum_child[x].
void dfs(int x, int p = -1){
	par[x] = p;
	for (const auto &child : adj[x]){
		if (child != p){
			dfs(child, x);
			sum_child[x] += v[child];
		}
	}
}

// K = 1 helper: sets max_path[x] to v[x] plus the largest of the current
// max_path[x] and the max_path of each child. Starting from zero-initialized
// entries, this is the best downward path value beginning at x.
void dfs_K1(int x, int p = -1){
	// The recursion only writes inside the child subtrees, so max_path[x] can
	// safely be accumulated in a local.
	int best = max_path[x];
	for (auto child : adj[x]){
		if (child == p) continue;
		dfs_K1(child, x);
		best = max(best, max_path[child]);
	}
	max_path[x] = best + v[x];
}

// K = 1 solver: the route is a single downward path from vertex 0.
// After dfs_K1, the path is rebuilt greedily by always stepping into the child
// with the largest max_path. Prints the total (max_path[0]), the path length,
// and the path (1-indexed, one line).
void solve_K1(){
	dfs_K1(0);

	vector<int> trace;
	trace.push_back(0);
	int crr = 0;

	while (true){
		// opt starts at -1, so a child is taken only when its max_path >= 0.
		// This matches the 0 baseline ("stop here") used by dfs_K1.
		int opt = -1, nxt = -1;
		for (auto k: adj[crr]){
			if (k == par[crr]) continue;
			if (max_path[k] > opt){
				opt = max_path[k];
				nxt = k;
			}
		}
		if (opt < 0) break;
		trace.push_back(nxt);
		crr = nxt;
	}

	cout << max_path[0] << "\n";
	cout << trace.size() << "\n";
	// Unsigned index avoids the signed/unsigned comparison with size().
	for (size_t i = 0; i < trace.size(); i++){
		cout << trace[i] + 1 << " \n"[i + 1 == trace.size()];
	}
}

// Visiting order (0-indexed) built by dfs_K3_1 / dfs_K3_2 for the K >= 3 solver.
vector<int> trace;
void dfs_K3_2(int node, int parent);

// K >= 3 helper, "pre-order" half: emit x first, then expand each child with
// the "post-order" half dfs_K3_2. The two halves alternate level by level.
// NOTE(review): presumably this keeps consecutive vertices within distance 3.
void dfs_K3_1(int x, int p = -1){
	trace.push_back(x);
	for (auto child : adj[x])
		if (child != p) dfs_K3_2(child, x);
}

// K >= 3 helper, "post-order" half: expand each child with dfs_K3_1 first,
// then emit x.
void dfs_K3_2(int x, int p = -1){
	for (auto child : adj[x])
		if (child != p) dfs_K3_1(child, x);
	trace.push_back(x);
}

// K >= 3 solver: the route visits all n vertices, so the reported total is the
// sum of every value. Prints the total, n, and the alternating pre/post-order
// walk from vertex 0 (1-indexed, one line).
void solve_K3(){
	const int total = accumulate(v.begin(), v.end(), 0LL);
	cout << total << "\n" << n << "\n";
	dfs_K3_1(0);
	for (int i = 0; i < n; i++){
		cout << trace[i] + 1 << (i == n - 1 ? '\n' : ' ');
	}
}

// Reads n and K, then n-1 undirected edges (1-indexed), then the n vertex
// values. Roots the tree at vertex 0 via dfs and dispatches on K:
// 1 -> solve_K1, 2 -> K2::solve, anything else -> solve_K3.
void main_program(){
	cin >> n >> K;

	adj.resize(n);
	v.resize(n);
	par.resize(n);
	sum_child.resize(n);
	max_path.resize(n);

	for (int e = 1; e < n; e++){
		int a, b;
		cin >> a >> b;
		--a, --b;
		adj[a].push_back(b);
		adj[b].push_back(a);
	}

	for (auto &val : v) cin >> val;

	dfs(0);
	switch (K){
		case 1:
			solve_K1();
			break;
		case 2:
			K2::solve();
			break;
		default:
			solve_K3();
			break;
	}
}

컴파일 시 표준 에러 (stderr) 메시지

Main.cpp: In function 'void K2::solve()':
Main.cpp:342:21: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  342 |   for (int i = 0; i < trace.size(); i++){
      |                   ~~^~~~~~~~~~~~~~
Main.cpp:343:36: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  343 |    cout << trace[i] + 1 << " \n"[i == trace.size() - 1];
      |                                  ~~^~~~~~~~~~~~~~~~~~~
Main.cpp: In function 'void solve_K1()':
Main.cpp:389:20: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  389 |  for (int i = 0; i < trace.size(); i++){
      |                  ~~^~~~~~~~~~~~~~
Main.cpp:390:35: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  390 |   cout << trace[i] + 1 << " \n"[i == trace.size() - 1];
      |                                 ~~^~~~~~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...