답안 #793970

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
793970 2023-07-26T08:24:42 Z 박상훈(#10058) Travelling Trader (CCO23_day2problem2) C++17
11 / 25
175 ms 34496 KB
#include <bits/stdc++.h>

using namespace std;
using ll = long long;

// Input shared by every solver below; cities are numbered from 1.
int n;                    // number of cities
vector<int> adj[200200];  // undirected adjacency lists
ll a[200200];             // a[i]: profit of city i

namespace foo{

ll dp[200200];
int nxt[200200];

void dfs(int s, int pa = 0){
	dp[s] = a[s], nxt[s] = 0;
	for (auto &v:adj[s]) if (v!=pa){
		dfs(v, s);
		if (dp[s] < dp[v] + a[s]){
			dp[s] = dp[v] + a[s];
			nxt[s] = v;
		}
	}
}

void solve1(){
	dfs(1);
	printf("%lld\n", dp[1]);

	vector<int> V;
	for (int i=1;i;i=nxt[i]){
		V.push_back(i);
	}

	printf("%d\n", (int)V.size());
	for (auto &x:V) printf("%d ", x);
	printf("\n");
}

} // namespace foo


namespace bar{

constexpr ll INF = 4e18;

ll dp[200200][2][3];
int v1[200200], v2[200200][2], v3[200200][3], typ2[200200], typ3[200200], nxt[200200];
vector<int> G[200200];

void dfs(int s, int pa = 0){
	vector<pair<ll, int>> C1, C2, C3;
	for (int i=1;i<=3;i++){
		C1.emplace_back(0, -i);
		C2.emplace_back(0, -i);
		C3.emplace_back(0, -i);
	}

	ll sum = 0;
	int deg = 0;

	for (auto &v:adj[s]) if (v!=pa){
		G[s].push_back(v);
		dfs(v, s);
		sum += a[v];
		deg++;

		C1.emplace_back(dp[v][1][0] - a[v], v);
		C2.emplace_back(dp[v][0][2] - a[v], v);
		C3.emplace_back(dp[v][1][2], v);
	}

	sort(C1.begin(), C1.end(), greater<pair<ll, int>>());
	sort(C2.begin(), C2.end(), greater<pair<ll, int>>());
	sort(C3.begin(), C3.end(), greater<pair<ll, int>>());

	// calc 0 1
	dp[s][0][1] = a[s] + sum;
	dp[s][0][1] += C1[0].first;
	v1[s] = C1[0].second;

	if (deg==0) dp[s][0][1] = -INF;

	// calc 0 2
	dp[s][0][2] = -INF;
	for (int i=0;i<2;i++){
		for (int j=0;j<2;j++) if (C1[i].second != C2[j].second){
			ll val = a[s] + sum + C1[i].first + C2[j].first;

			if (dp[s][0][2] < val){
				dp[s][0][2] = val;
				typ2[s] = 2;

				v2[s][0] = C1[i].second;
				v2[s][1] = C2[j].second;
			}
		}
	}

	if (dp[s][0][2] < a[s] + C3[0].first){
		dp[s][0][2] = a[s] + C3[0].first;
		typ2[s] = 1;

		v2[s][0] = C3[0].second;
	}

	// calc 1 2
	dp[s][1][2] = -INF;

	for (int i=0;i<3;i++){
		for (int j=i+1;j<3;j++){
			for (int k=0;k<3;k++) if (C1[i].second != C2[k].second && C1[j].second != C2[k].second && C1[i].second > 0){
				ll val = a[s] + sum + C1[i].first + C1[j].first + C2[k].first;

				if (dp[s][1][2] < val){
					dp[s][1][2] = val;
					typ3[s] = 3;

					v3[s][0] = C1[i].second;
					v3[s][1] = C1[j].second;
					v3[s][2] = C2[k].second;
				}
			}
		}
	}

	for (int i=0;i<2;i++){
		for (int j=0;j<2;j++) if (C1[i].second != C3[j].second && C1[i].second > 0){
			ll val = a[s] + C1[i].first + C3[j].first;
			if (C1[i].second > 0) val += a[C1[i].second];

			if (dp[s][1][2] < val){
				dp[s][1][2] = val;
				typ3[s] = 2;

				v3[s][0] = C1[i].second;
				v3[s][1] = C3[j].second;
			}
		}
	}

	if (deg==0) dp[s][1][2] = -INF;

	// calc 1 0
	dp[s][1][0] = dp[s][0][1];

	if (dp[s][1][2] < dp[s][1][0]) dp[s][1][2] = dp[s][1][0], typ3[s] = 0;


	// printf("\nok %d done / sum = %lld / deg = %d / val = %lld\n", s, sum, deg, a[s]);
	// printf("0 1: %lld / v1 = %d\n", dp[s][0][1], v1[s]);
	// printf("1 0: %lld / v1 = %d\n", dp[s][1][0], v1[s]);
	// printf("0 2: %lld / typ2 = %d / v2 = %d %d\n", dp[s][0][2], typ2[s], v2[s][0], v2[s][1]);
	// printf("1 2: %lld / typ3 = %d / v3 = %d %d %d\n", dp[s][1][2], typ3[s], v3[s][0], v3[s][1], v3[s][2]);
}

pair<int, int> track(int s, int x, int y){
	// printf(" call %d %d %d\n", s, x, y);
	assert(s > 0);
	assert(dp[s][x][y] >= 0);

	if (x==1 && y==2 && typ3[s]==0) y = 0;

	if (x==0 && y==1){
		pair<int, int> ret = {s, s};

		if (v1[s] > 0){
			auto [l1, r1] = track(v1[s], 1, 0);
			nxt[s] = l1;
			ret.second = r1;

			G[s].erase(find(G[s].begin(), G[s].end(), v1[s]));
		}

		for (auto &v:G[s]){
			nxt[ret.second] = v;
			ret.second = v;
		}

		return ret;
	}

	if (x==0 && y==2){
		pair<int, int> ret = {s, s};

		if (typ2[s]==1){
			if (v2[s][0] > 0){
				auto [l1, r1] = track(v2[s][0], 1, 2);
				nxt[s] = l1;

				ret.second = r1;
			}

			return ret;
			
		}

		// else (typ2[s]==2)
		if (v2[s][0] > 0){
			auto [l1, r1] = track(v2[s][0], 1, 0);
			nxt[s] = l1;

			ret.second = r1;
			G[s].erase(find(G[s].begin(), G[s].end(), v2[s][0]));
		}

		if (v2[s][1] > 0){
			G[s].erase(find(G[s].begin(), G[s].end(), v2[s][1]));
		}

		for (auto &v:G[s]){
			nxt[ret.second] = v;
			v = ret.second;
		}

		if (v2[s][1] > 0){
			nxt[ret.second] = v2[s][1];
			ret.second = v2[s][1];

			auto [l2, r2] = track(v2[s][1], 0, 2);
			ret.second = r2;
		}

		return ret;
	}

	if (x==1 && y==0){
		pair<int, int> ret = {s, s};

		if (v1[s] > 0){
			auto [l1, r1] = track(v1[s], 0, 1);
			nxt[r1] = s;
			ret.first = l1;

			G[s].erase(find(G[s].begin(), G[s].end(), v1[s]));
		}

		for (auto &v:G[s]){
			nxt[v] = ret.first;
			ret.first = v;
		}

		assert(ret.first != s);
		return ret;
	}

	assert(x==1 && y==2);

	pair<int, int> ret = {s, s};

	if (typ3[s]==2){
		if (v3[s][0] > 0){
			auto [l1, r1] = track(v3[s][0], 0, 1);
			ret.first = l1;
			nxt[r1] = s;

			G[s].erase(find(G[s].begin(), G[s].end(), v3[s][0]));
		}

		if (v3[s][1] > 0){
			auto [l2, r2] = track(v3[s][1], 1, 2);
			nxt[s] = l2;
			ret.second = r2;
		}

		assert(ret.first != s);
		return ret;
	}

	//typ3[s] == 3
	if (v3[s][0] > 0){
		auto [l1, r1] = track(v3[s][0], 0, 1);
		ret.first = l1;
		nxt[r1] = s;

		G[s].erase(find(G[s].begin(), G[s].end(), v3[s][0]));
	}

	if (v3[s][1] > 0){
		auto [l2, r2] = track(v3[s][1], 1, 0);
		nxt[s] = l2;
		ret.second = r2;

		G[s].erase(find(G[s].begin(), G[s].end(), v3[s][1]));
	}

	if (v3[s][2] > 0){
		G[s].erase(find(G[s].begin(), G[s].end(), v3[s][2]));
	}

	for (auto &v:G[s]){
		nxt[ret.second] = v;
		ret.second = v;
	}

	if (v3[s][2] > 0){
		auto [l3, r3] = track(v3[s][2], 0, 2);
		nxt[ret.second] = l3;
		ret.second = r3;
	}

	assert(ret.first != s);
	return ret;
}

void solve2(){
	dfs(1);

	ll ans = dp[1][0][2];
	printf("%lld\n", ans);

	track(1, 0, 2);

	vector<int> V;
	for (int i=1;i;i=nxt[i]) V.push_back(i);

	printf("%d\n", (int)V.size());
	for (auto &x:V) printf("%d ", x);
	printf("\n");
}

} // namespace bar

namespace baz{

void dfs1(int s, int pa);
void dfs2(int s, int pa);

vector<int> V;

void dfs1(int s, int pa){
	V.push_back(s);
	for (auto &v:adj[s]) if (v!=pa) dfs2(v, s);
}

void dfs2(int s, int pa){
	for (auto &v:adj[s]) if (v!=pa) dfs1(v, s);
	V.push_back(s);
}

void solve3(){
	dfs1(1, 0);

	ll sum = 0;
	for (int i=1;i<=n;i++) sum += a[i];

	printf("%lld\n", sum);
	printf("%d\n", n);
	for (auto &x:V) printf("%d ", x);
	printf("\n");
}

} // namespace baz

// Reads "n k", then n-1 edges, then n profits, and dispatches to the solver for k.
// Returns 1 if the input ends early or is malformed, instead of running on
// uninitialized values.
int main(){
	int k;
	if (scanf("%d %d", &n, &k) != 2) return 1;

	for (int i=1;i<=n-1;i++){
		int x, y;
		if (scanf("%d %d", &x, &y) != 2) return 1;
		adj[x].push_back(y);
		adj[y].push_back(x);
	}

	for (int i=1;i<=n;i++){
		if (scanf("%lld", a+i) != 1) return 1;
	}

	if (k==1) foo::solve1();
	else if (k==2) bar::solve2();
	else baz::solve3();

	return 0;
}

Compilation message

Main.cpp: In function 'int main()':
Main.cpp:357:7: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  357 |  scanf("%d %d", &n, &k);
      |  ~~~~~^~~~~~~~~~~~~~~~~
Main.cpp:361:8: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  361 |   scanf("%d %d", &x, &y);
      |   ~~~~~^~~~~~~~~~~~~~~~~
Main.cpp:366:30: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  366 |  for (int i=1;i<=n;i++) scanf("%lld", a+i);
      |                         ~~~~~^~~~~~~~~~~~~
# 결과 실행 시간 메모리 Grader output
1 Correct 5 ms 9684 KB Output is correct
2 Correct 5 ms 9764 KB Output is correct
3 Correct 115 ms 19992 KB Output is correct
4 Correct 104 ms 20068 KB Output is correct
5 Correct 122 ms 19948 KB Output is correct
6 Correct 109 ms 20804 KB Output is correct
7 Correct 72 ms 20664 KB Output is correct
8 Correct 87 ms 20268 KB Output is correct
9 Correct 175 ms 34496 KB Output is correct
10 Correct 151 ms 27368 KB Output is correct
11 Correct 72 ms 19860 KB Output is correct
12 Correct 5 ms 9684 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 6 ms 9684 KB Output is correct
2 Correct 5 ms 9684 KB Output is correct
3 Incorrect 5 ms 9684 KB total profit is not correct
4 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 6 ms 9684 KB Output is correct
2 Correct 5 ms 9684 KB Output is correct
3 Incorrect 5 ms 9684 KB total profit is not correct
4 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 6 ms 9684 KB Output is correct
2 Correct 5 ms 9684 KB Output is correct
3 Incorrect 5 ms 9684 KB total profit is not correct
4 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 5 ms 9812 KB Output is correct
2 Correct 5 ms 9836 KB Output is correct
3 Correct 6 ms 9784 KB Output is correct
4 Correct 6 ms 9812 KB Output is correct
5 Correct 5 ms 9812 KB Output is correct
6 Correct 5 ms 9812 KB Output is correct
7 Correct 5 ms 9720 KB Output is correct
8 Correct 5 ms 9620 KB Output is correct
9 Correct 7 ms 9812 KB Output is correct
10 Correct 5 ms 9812 KB Output is correct
11 Correct 5 ms 9796 KB Output is correct
12 Correct 5 ms 9816 KB Output is correct
13 Correct 5 ms 9812 KB Output is correct
14 Correct 5 ms 9812 KB Output is correct
15 Correct 8 ms 9724 KB Output is correct
16 Correct 5 ms 9812 KB Output is correct
17 Correct 5 ms 9812 KB Output is correct
18 Correct 5 ms 9812 KB Output is correct
19 Correct 7 ms 9812 KB Output is correct
20 Correct 5 ms 9812 KB Output is correct
21 Correct 5 ms 9812 KB Output is correct
22 Correct 5 ms 9812 KB Output is correct
23 Correct 5 ms 9812 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 5 ms 9812 KB Output is correct
2 Correct 5 ms 9836 KB Output is correct
3 Correct 6 ms 9784 KB Output is correct
4 Correct 6 ms 9812 KB Output is correct
5 Correct 5 ms 9812 KB Output is correct
6 Correct 5 ms 9812 KB Output is correct
7 Correct 5 ms 9720 KB Output is correct
8 Correct 5 ms 9620 KB Output is correct
9 Correct 7 ms 9812 KB Output is correct
10 Correct 5 ms 9812 KB Output is correct
11 Correct 5 ms 9796 KB Output is correct
12 Correct 5 ms 9816 KB Output is correct
13 Correct 5 ms 9812 KB Output is correct
14 Correct 5 ms 9812 KB Output is correct
15 Correct 8 ms 9724 KB Output is correct
16 Correct 5 ms 9812 KB Output is correct
17 Correct 5 ms 9812 KB Output is correct
18 Correct 5 ms 9812 KB Output is correct
19 Correct 7 ms 9812 KB Output is correct
20 Correct 5 ms 9812 KB Output is correct
21 Correct 5 ms 9812 KB Output is correct
22 Correct 5 ms 9812 KB Output is correct
23 Correct 5 ms 9812 KB Output is correct
24 Correct 114 ms 19876 KB Output is correct
25 Correct 138 ms 19908 KB Output is correct
26 Correct 120 ms 19884 KB Output is correct
27 Correct 150 ms 19840 KB Output is correct
28 Correct 153 ms 19904 KB Output is correct
29 Correct 138 ms 19744 KB Output is correct
30 Correct 103 ms 20468 KB Output is correct
31 Correct 122 ms 19768 KB Output is correct
32 Correct 130 ms 20628 KB Output is correct
33 Correct 117 ms 19788 KB Output is correct
34 Correct 113 ms 20576 KB Output is correct
35 Correct 86 ms 20416 KB Output is correct
36 Correct 116 ms 20072 KB Output is correct
37 Correct 121 ms 19828 KB Output is correct
38 Correct 134 ms 20664 KB Output is correct
39 Correct 155 ms 29004 KB Output is correct
40 Correct 123 ms 24420 KB Output is correct
41 Correct 155 ms 23008 KB Output is correct
42 Correct 148 ms 21992 KB Output is correct
43 Correct 170 ms 20812 KB Output is correct
44 Correct 149 ms 20612 KB Output is correct
45 Correct 88 ms 19716 KB Output is correct
46 Correct 88 ms 19620 KB Output is correct