답안 #851622

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
851622 2023-09-20T09:36:15 Z TranGiaHuy1508 Travelling Trader (CCO23_day2problem2) C++17
11 / 25
161 ms 37824 KB
#include <bits/stdc++.h>
using namespace std;

void main_program();

// Program entry point: put the standard streams into unsynchronised, untied
// mode for fast bulk I/O, then hand control to main_program().
signed main(){
	std::ios_base::sync_with_stdio(false);
	std::cin.tie(nullptr);
	main_program();
	return 0;
}

// From this point on every `int` token (signatures, locals, containers) is
// spelled `long long`, giving 64-bit arithmetic throughout.
#define int long long

// Large sentinel (10^15); K2::solve fills its DP tables with -inf.
constexpr int inf = 1'000'000'000'000'000;

// Shared tree state, filled by main_program() and dfs().
int n;                     // number of nodes
int K;                     // selects the solver: 1, 2, or anything else
vector<int> v;             // v[x]: weight of node x, read from input
vector<int> par;           // par[x]: parent of x after dfs(0), -1 for the root
vector<vector<int>> adj;   // 0-based adjacency lists
vector<int> sum_child;     // sum_child[x]: total of v over the children of x
vector<int> max_path;      // scratch table of the K == 1 solver

/*
 * K2 - solver used when K == 2.
 *
 * Four memoised tree DPs. Each describes the best route through the subtree
 * of x, where the value of a route is the sum of v[] over the nodes it
 * visits. By construction, consecutive stops of every route built here are
 * at tree distance 1 or 2.
 *
 *   f_00(x): starts at a child of x (at x itself for a leaf) and ends at x.
 *   f_11(x): starts at x and ends at a child of x (at x itself for a leaf).
 *   f_0x(x): starts at a child of x or at x, and ends anywhere in the subtree.
 *   f_1x(x): starts at x and ends anywhere in the subtree.
 *
 * Every f_* records its choices in the matching trace_* array. The tracing_*
 * functions replay those choices and append the route (0-based ids) to
 * `trace`. In the trace arrays, a negative id always means "slot unused".
 */
namespace K2 {
	// The three children picked by f_0x (a negative id means "slot unused").
	struct Pair3{
		int path_11, path_00, path_1x;

		// Lexicographic order. It is only used to break ties inside std::max.
		bool operator < (const Pair3 &P) const {
			if (path_11 != P.path_11) return (path_11 < P.path_11);
			if (path_00 != P.path_00) return (path_00 < P.path_00);
			return (path_1x < P.path_1x);
		}
	};

	// Placeholder ids appended to the candidate lists to mean "leave this slot
	// empty". They are negative, so they never name a node. They are pairwise
	// distinct, so the "children must differ" checks never reject two empty
	// slots.
	const int NONE_11 = -2, NONE_00 = -3, NONE_1x = -4;

	// Memo tables. A negative entry means "not computed yet"; computed values
	// are sums of v (assumed non-negative by this check).
	vector<int> dp_00, dp_11, dp_1x, dp_0x;
	// trace_00[x] / trace_11[x]: the child expanded by f_00 / f_11, or -1.
	vector<int> trace_00, trace_11;
	// trace_1x[x] = {child used as the f_1x tail, child used as the f_00 detour}.
	vector<pair<int, int>> trace_1x;
	// trace_skip_1x[x]: the child whose f_0x route f_1x(x) jumps into, or -1.
	vector<int> trace_skip_1x;
	// trace_0x[x]: the children used for the three f_0x slots.
	vector<Pair3> trace_0x;

	int f_00(int x, int p);
	int f_11(int x, int p);
	int f_0x(int x, int p);
	int f_1x(int x, int p);

	// Route that starts at a child of x (at x itself for a leaf) and ends at x.
	// Every child of x is visited. At most one child k is expanded with its
	// f_11 route; k is chosen to maximise the extra gain f_11(k) - v[k].
	int f_00(int x, int p = -1){
		if (dp_00[x] >= 0) return dp_00[x];
		int opt = 0, best = -1;
		for (auto k: adj[x]){
			if (k == p) continue;
			int diff = f_11(k, x) - v[k];
			if (diff > opt){
				opt = diff;
				best = k;
			}
		}
		trace_00[x] = best;
		dp_00[x] = sum_child[x] + v[x] + opt;
		return dp_00[x];
	}

	// Same shape as f_00, traversed in the other direction. It starts at x and
	// ends at a child of x (at x itself for a leaf). At most one child k is
	// expanded with its f_00 route.
	int f_11(int x, int p = -1){
		if (dp_11[x] >= 0) return dp_11[x];
		int opt = 0, best = -1;
		for (auto k: adj[x]){
			if (k == p) continue;
			int diff = f_00(k, x) - v[k];
			if (diff > opt){
				opt = diff;
				best = k;
			}
		}
		trace_11[x] = best;
		dp_11[x] = sum_child[x] + v[x] + opt;
		return dp_11[x];
	}

	// Route that starts at a child of x (or at x) and ends anywhere in x's
	// subtree. Besides visiting every child of x, it may use three distinct
	// children:
	//  * an f_11 route walked before x,
	//  * an f_00 detour right after x,
	//  * an f_1x route that finishes the walk.
	// Any of the three slots may also stay unused.
	int f_0x(int x, int p = -1){
		if (dp_0x[x] >= 0) return dp_0x[x];
		vector<pair<int, int>> delta_11, delta_00, delta_1x;

		for (auto k: adj[x]){
			if (k == p) continue;
			delta_11.emplace_back(f_11(k, x) - v[k], k);
			delta_00.emplace_back(f_00(k, x) - v[k], k);
			delta_1x.emplace_back(f_1x(k, x) - v[k], k);
		}

		if (delta_00.empty()){
			dp_0x[x] = v[x];
			trace_0x[x] = Pair3{-1, -1, -1};
			return dp_0x[x];
		}

		sort(delta_11.begin(), delta_11.end(), greater<>());
		sort(delta_00.begin(), delta_00.end(), greater<>());
		sort(delta_1x.begin(), delta_1x.end(), greater<>());

		// Three pairwise-distinct children are picked, so only the top three
		// candidates of each list can be part of the optimum.
		const size_t maxlen = 3;
		delta_11.resize(min(delta_11.size(), maxlen));
		delta_00.resize(min(delta_00.size(), maxlen));
		delta_1x.resize(min(delta_1x.size(), maxlen));

		// Let every slot stay empty (zero gain). Previously only the f_00 slot
		// could be empty. That meant, for example, that a node with a single
		// child got no extension into that child at all.
		delta_11.emplace_back(0, NONE_11);
		delta_00.emplace_back(0, NONE_00);
		delta_1x.emplace_back(0, NONE_1x);

		pair<int, Pair3> final_opt(0, Pair3{-1, -1, -1});
		for (auto p11: delta_11){
			for (auto p00: delta_00){
				if (p11.second == p00.second) continue;
				for (auto p1x: delta_1x){
					if (p11.second == p1x.second) continue;
					if (p00.second == p1x.second) continue;

					pair<int, Pair3> crr_opt = {
						p11.first + p00.first + p1x.first,
						Pair3{p11.second, p00.second, p1x.second}
					};
					final_opt = max(final_opt, crr_opt);
				}
			}
		}

		trace_0x[x] = final_opt.second;
		dp_0x[x] = sum_child[x] + v[x] + final_opt.first;
		return dp_0x[x];
	}

	// Route that starts at x and ends anywhere in x's subtree. Two shapes:
	//  * x, then an optional f_00 detour through one child, then the remaining
	//    children, then an optional f_1x tail through another child;
	//  * x, then straight into one child's f_0x route. The other children of
	//    x are not visited in this shape.
	int f_1x(int x, int p = -1){
		if (dp_1x[x] >= 0) return dp_1x[x];
		vector<pair<int, int>> delta_00, delta_1x;
		for (auto k: adj[x]){
			if (k == p) continue;
			delta_00.emplace_back(f_00(k, x) - v[k], k);
			delta_1x.emplace_back(f_1x(k, x) - v[k], k);
		}

		if (delta_00.empty()){
			dp_1x[x] = v[x];
			trace_1x[x] = {-1, -1};
			return dp_1x[x];
		}

		sort(delta_00.begin(), delta_00.end(), greater<>());
		sort(delta_1x.begin(), delta_1x.end(), greater<>());

		// Two distinct children are picked, so the top two of each list suffice.
		const size_t maxlen = 2;
		delta_00.resize(min(delta_00.size(), maxlen));
		delta_1x.resize(min(delta_1x.size(), maxlen));

		// Let the detour and/or the tail stay empty. Without these placeholders,
		// a node with one child could never continue its route into that child.
		delta_00.emplace_back(0, NONE_00);
		delta_1x.emplace_back(0, NONE_1x);

		pair<int, pair<int, int>> final_opt(0, {-1, -1});
		for (auto p00: delta_00){
			for (auto p1x: delta_1x){
				if (p00.second == p1x.second) continue;

				pair<int, pair<int, int>> crr_opt = {
					p00.first + p1x.first,
					{p1x.second, p00.second}
				};
				final_opt = max(final_opt, crr_opt);
			}
		}

		int ans = sum_child[x] + v[x] + final_opt.first;
		trace_1x[x] = final_opt.second;

		// Alternative shape: jump from x directly into a child's f_0x route.
		for (auto k: adj[x]){
			if (k == p) continue;
			int newval = v[x] + f_0x(k, x);
			if (newval > ans){
				trace_skip_1x[x] = k;
				ans = newval;
			}
		}

		dp_1x[x] = ans;
		return dp_1x[x];
	}

	// Node ids (0-based) of the reconstructed route, in visiting order.
	vector<int> trace;

	void tracing_00(int x, int p);
	void tracing_11(int x, int p);
	void tracing_0x(int x, int p);
	void tracing_1x(int x, int p);

	// Append the route chosen by f_00(x): the children of x other than the
	// expanded one, then the expanded child's f_11 route, then x itself.
	void tracing_00(int x, int p = -1){
		for (auto k: adj[x]){
			if (k == p) continue;
			if (k == trace_00[x]) continue;
			trace.push_back(k);
		}
		if (trace_00[x] >= 0){
			tracing_11(trace_00[x], x);
		}
		trace.push_back(x);
	}

	// Append the route chosen by f_11(x): x, then the expanded child's f_00
	// route, then the remaining children.
	void tracing_11(int x, int p = -1){
		trace.push_back(x);
		if (trace_11[x] >= 0){
			tracing_00(trace_11[x], x);
		}
		for (auto k: adj[x]){
			if (k == p) continue;
			if (k == trace_11[x]) continue;
			trace.push_back(k);
		}
	}

	// Append the route chosen by f_0x(x), in this order: the f_11 slot, x, the
	// f_00 slot, the remaining children, the f_1x slot. Unused (negative)
	// slots are skipped.
	void tracing_0x(int x, int p = -1){
		if (trace_0x[x].path_11 >= 0){
			tracing_11(trace_0x[x].path_11, x);
		}

		trace.push_back(x);

		if (trace_0x[x].path_00 >= 0){
			tracing_00(trace_0x[x].path_00, x);
		}

		for (auto k: adj[x]){
			if (k == p) continue;
			if (k == trace_0x[x].path_11) continue;
			if (k == trace_0x[x].path_00) continue;
			if (k == trace_0x[x].path_1x) continue;
			trace.push_back(k);
		}

		if (trace_0x[x].path_1x >= 0){
			tracing_1x(trace_0x[x].path_1x, x);
		}
	}

	// Append the route chosen by f_1x(x). It is either x followed by the chosen
	// child's f_0x route, or x, the f_00 detour, the remaining children and
	// the f_1x tail.
	void tracing_1x(int x, int p = -1){
		trace.push_back(x);

		if (trace_skip_1x[x] >= 0){
			tracing_0x(trace_skip_1x[x], x);
			return;
		}

		if (trace_1x[x].second >= 0){
			tracing_00(trace_1x[x].second, x);
		}
		for (auto k: adj[x]){
			if (k == p) continue;
			if (k == trace_1x[x].first) continue;
			if (k == trace_1x[x].second) continue;
			trace.push_back(k);
		}
		if (trace_1x[x].first >= 0){
			tracing_1x(trace_1x[x].first, x);
		}
	}

	// Fill the tables from the root (node 0), then print three things: the
	// best value, the route length, and the route using 1-based ids.
	void solve(){
		dp_00.assign(n, -inf);
		dp_11.assign(n, -inf);
		dp_0x.assign(n, -inf);
		dp_1x.assign(n, -inf);

		trace_00.assign(n, -1);
		trace_11.assign(n, -1);
		trace_1x.assign(n, {-1, -1});
		trace_0x.assign(n, Pair3{-1, -1, -1});
		trace_skip_1x.assign(n, -1);

		cout << f_1x(0) << "\n";
		tracing_1x(0);
		cout << trace.size() << "\n";
		// trace is never empty (tracing_1x always pushes the root).
		for (size_t i = 0; i < trace.size(); i++){
			cout << trace[i] + 1 << " \n"[i + 1 == trace.size()];
		}
	}
}

// Root the tree: store each node's parent in par[] (-1 for the starting node)
// and add the weights v[] of x's direct children into sum_child[x].
void dfs(int x, int p = -1){
	par[x] = p;
	for (size_t i = 0; i < adj[x].size(); ++i){
		const int child = adj[x][i];
		if (child != p){
			dfs(child, x);
			sum_child[x] += v[child];
		}
	}
}

// Post-order pass. It sets max_path[x] to v[x] plus the larger of two things:
// its previous value (0 after the resize in main_program) and the largest
// max_path among x's children.
void dfs_K1(int x, int p = -1){
	int best_below = max_path[x];
	for (auto child: adj[x]){
		if (child == p) continue;
		dfs_K1(child, x);
		best_below = max(best_below, max_path[child]);
	}
	max_path[x] = best_below + v[x];
}

// K == 1 solver. Computes max_path with dfs_K1, then walks down from the root
// (node 0), always stepping to the child with the largest max_path. It prints
// max_path[0], the number of nodes on the walk, and the walk (1-based ids).
void solve_K1(){
	dfs_K1(0);

	vector<int> trace;
	trace.push_back(0);
	int crr = 0;

	while (true){
		int opt = -1, nxt = -1;
		for (auto k: adj[crr]){
			if (k == par[crr]) continue;
			if (max_path[k] > opt){
				opt = max_path[k];
				nxt = k;
			}
		}
		// No child with a non-negative max_path: stop here. With non-negative
		// weights, that means crr is a leaf.
		if (opt < 0) break;
		trace.push_back(nxt);
		crr = nxt;
	}

	cout << max_path[0] << "\n";
	cout << trace.size() << "\n";
	// Unsigned index avoids the signed/unsigned comparison against size().
	// trace always holds at least the root.
	for (size_t i = 0; i < trace.size(); i++){
		cout << trace[i] + 1 << " \n"[i + 1 == trace.size()];
	}
}

// Visiting order produced for the fallback solver (solve_K3).
vector<int> trace;

// Alternating layout. A node visited with node_first == true is appended
// before its subtree; otherwise it is appended after it. The flag flips at
// every level of the tree.
void dfs_K3_visit(int x, int p, bool node_first){
	if (node_first) trace.push_back(x);
	for (auto child: adj[x]){
		if (child != p) dfs_K3_visit(child, x, !node_first);
	}
	if (!node_first) trace.push_back(x);
}

// x first, then its children's subtrees laid out by dfs_K3_2.
void dfs_K3_1(int x, int p = -1){
	dfs_K3_visit(x, p, true);
}

// Children's subtrees laid out by dfs_K3_1, then x last.
void dfs_K3_2(int x, int p = -1){
	dfs_K3_visit(x, p, false);
}

// Fallback solver, used when K is neither 1 nor 2. It prints the sum of all
// weights, the node count n, and the order built by dfs_K3_1 from the root
// (1-based ids).
void solve_K3(){
	const int total = accumulate(v.begin(), v.end(), 0LL);
	cout << total << "\n" << n << "\n";

	dfs_K3_1(0);
	for (int idx = 0; idx < n; ++idx){
		const char sep = (idx == n - 1) ? '\n' : ' ';
		cout << trace[idx] + 1 << sep;
	}
}

// Read input in this order: n and K; then n - 1 undirected edges with 1-based
// endpoints; then n node weights. Root the tree at node 0 and dispatch on K.
void main_program(){
	cin >> n >> K;

	v.resize(n);
	par.resize(n);
	adj.resize(n);
	max_path.resize(n);
	sum_child.resize(n);

	for (int e = 1; e < n; e++){
		int a, b;
		cin >> a >> b;
		--a, --b;
		adj[a].push_back(b);
		adj[b].push_back(a);
	}

	for (auto &val: v) cin >> val;

	dfs(0);
	switch (K){
		case 1:
			solve_K1();
			break;
		case 2:
			K2::solve();
			break;
		default:
			solve_K3();
			break;
	}
}

Compilation message

Main.cpp: In function 'void K2::solve()':
Main.cpp:285:21: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  285 |   for (int i = 0; i < trace.size(); i++){
      |                   ~~^~~~~~~~~~~~~~
Main.cpp:286:36: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  286 |    cout << trace[i] + 1 << " \n"[i == trace.size() - 1];
      |                                  ~~^~~~~~~~~~~~~~~~~~~
Main.cpp: In function 'void solve_K1()':
Main.cpp:332:20: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  332 |  for (int i = 0; i < trace.size(); i++){
      |                  ~~^~~~~~~~~~~~~~
Main.cpp:333:35: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  333 |   cout << trace[i] + 1 << " \n"[i == trace.size() - 1];
      |                                 ~~^~~~~~~~~~~~~~~~~~~
# 결과 실행 시간 메모리 Grader output
1 Correct 0 ms 344 KB Output is correct
2 Correct 1 ms 348 KB Output is correct
3 Correct 88 ms 21836 KB Output is correct
4 Correct 99 ms 23120 KB Output is correct
5 Correct 109 ms 23120 KB Output is correct
6 Correct 87 ms 24040 KB Output is correct
7 Correct 64 ms 22820 KB Output is correct
8 Correct 81 ms 22472 KB Output is correct
9 Correct 154 ms 37824 KB Output is correct
10 Correct 130 ms 32200 KB Output is correct
11 Correct 72 ms 23376 KB Output is correct
12 Correct 1 ms 600 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 0 ms 344 KB Output is correct
2 Correct 0 ms 344 KB Output is correct
3 Incorrect 1 ms 344 KB Output isn't correct
4 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 0 ms 344 KB Output is correct
2 Correct 0 ms 344 KB Output is correct
3 Incorrect 1 ms 344 KB Output isn't correct
4 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 0 ms 344 KB Output is correct
2 Correct 0 ms 344 KB Output is correct
3 Incorrect 1 ms 344 KB Output isn't correct
4 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 600 KB Output is correct
2 Correct 1 ms 600 KB Output is correct
3 Correct 2 ms 600 KB Output is correct
4 Correct 1 ms 600 KB Output is correct
5 Correct 1 ms 600 KB Output is correct
6 Correct 1 ms 604 KB Output is correct
7 Correct 0 ms 344 KB Output is correct
8 Correct 1 ms 344 KB Output is correct
9 Correct 1 ms 600 KB Output is correct
10 Correct 1 ms 600 KB Output is correct
11 Correct 1 ms 600 KB Output is correct
12 Correct 1 ms 600 KB Output is correct
13 Correct 1 ms 600 KB Output is correct
14 Correct 1 ms 600 KB Output is correct
15 Correct 1 ms 600 KB Output is correct
16 Correct 1 ms 600 KB Output is correct
17 Correct 1 ms 600 KB Output is correct
18 Correct 1 ms 600 KB Output is correct
19 Correct 1 ms 600 KB Output is correct
20 Correct 1 ms 600 KB Output is correct
21 Correct 1 ms 600 KB Output is correct
22 Correct 2 ms 600 KB Output is correct
23 Correct 1 ms 600 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 600 KB Output is correct
2 Correct 1 ms 600 KB Output is correct
3 Correct 2 ms 600 KB Output is correct
4 Correct 1 ms 600 KB Output is correct
5 Correct 1 ms 600 KB Output is correct
6 Correct 1 ms 604 KB Output is correct
7 Correct 0 ms 344 KB Output is correct
8 Correct 1 ms 344 KB Output is correct
9 Correct 1 ms 600 KB Output is correct
10 Correct 1 ms 600 KB Output is correct
11 Correct 1 ms 600 KB Output is correct
12 Correct 1 ms 600 KB Output is correct
13 Correct 1 ms 600 KB Output is correct
14 Correct 1 ms 600 KB Output is correct
15 Correct 1 ms 600 KB Output is correct
16 Correct 1 ms 600 KB Output is correct
17 Correct 1 ms 600 KB Output is correct
18 Correct 1 ms 600 KB Output is correct
19 Correct 1 ms 600 KB Output is correct
20 Correct 1 ms 600 KB Output is correct
21 Correct 1 ms 600 KB Output is correct
22 Correct 2 ms 600 KB Output is correct
23 Correct 1 ms 600 KB Output is correct
24 Correct 111 ms 26300 KB Output is correct
25 Correct 109 ms 26304 KB Output is correct
26 Correct 115 ms 26292 KB Output is correct
27 Correct 134 ms 26304 KB Output is correct
28 Correct 123 ms 26292 KB Output is correct
29 Correct 114 ms 26096 KB Output is correct
30 Correct 96 ms 26816 KB Output is correct
31 Correct 124 ms 26012 KB Output is correct
32 Correct 115 ms 27104 KB Output is correct
33 Correct 115 ms 25984 KB Output is correct
34 Correct 108 ms 27056 KB Output is correct
35 Correct 73 ms 25800 KB Output is correct
36 Correct 99 ms 25408 KB Output is correct
37 Correct 103 ms 25920 KB Output is correct
38 Correct 86 ms 26824 KB Output is correct
39 Correct 136 ms 37820 KB Output is correct
40 Correct 125 ms 33692 KB Output is correct
41 Correct 115 ms 31176 KB Output is correct
42 Correct 161 ms 30428 KB Output is correct
43 Correct 126 ms 29128 KB Output is correct
44 Correct 130 ms 27580 KB Output is correct
45 Correct 73 ms 26568 KB Output is correct
46 Correct 70 ms 25796 KB Output is correct