제출 #794248

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
794248 | — | qwerasdfzxcl | Travelling Trader (CCO23_day2problem2) | C++17 | 11 / 25 | 194 ms | 34540 KiB
#include <bits/stdc++.h>

using namespace std;
typedef long long ll;

// Vertices are numbered 1..n; per-vertex arrays are sized with some slack.
constexpr int MAXV = 200200;

int n;                  // number of vertices, read in main()
vector<int> adj[MAXV];  // adj[v]: neighbours of v (each tree edge stored in both directions)
ll a[MAXV];             // a[v]: value of vertex v, read in main()

namespace foo{

ll dp[200200];
int nxt[200200];

void dfs(int s, int pa = 0){
	dp[s] = a[s], nxt[s] = 0;
	for (auto &v:adj[s]) if (v!=pa){
		dfs(v, s);
		if (dp[s] < dp[v] + a[s]){
			dp[s] = dp[v] + a[s];
			nxt[s] = v;
		}
	}
}

void solve1(){
	dfs(1);
	printf("%lld\n", dp[1]);

	vector<int> V;
	for (int i=1;i;i=nxt[i]){
		V.push_back(i);
	}

	printf("%d\n", (int)V.size());
	for (auto &x:V) printf("%d ", x);
	printf("\n");
}

} // namespace foo


namespace bar{

// k == 2 solver: tree DP rooted at vertex 1, followed by route reconstruction.
//
// dp[s][x][y] is the best sum of a[] over a route that uses only vertices of the
// subtree of s. The shape of each state, as replayed by track() below:
//   x = 0 : the route starts at s          x = 1 : it starts at a child of s
//   y = 0 : the route ends at s            y = 1 : it ends at a child of s
//   y = 2 : the end is free (dp[s][1][2] is never below dp[s][1][0])
// Only (0,1), (0,2), (1,0) and (1,2) are computed. -INF marks an impossible
// state, e.g. a leaf has no child to start or end at.
// NOTE(review): assumes consecutive stops must be at tree distance <= 2, as in the
// CCO 2023 "Travelling Trader" statement — confirm.

constexpr ll INF = 4e18;

ll dp[200200][2][3];
// Choices recorded by dfs() and replayed by track():
//   v1[s]      child whose (1,0) route is spliced into dp[s][0][1]
//   v2[s][..]  children used by dp[s][0][2]; typ2[s] selects the variant (1 or 2)
//   v3[s][..]  children used by dp[s][1][2]; typ3[s] selects the variant (0, 2 or 3)
// A negative id is a "no child" placeholder. An id > n in v3[s][0] encodes
// "visit only child (id - n), none of its subtree".
// nxt[u]: vertex visited right after u in the rebuilt route (0 = route ends).
int v1[200200], v2[200200][2], v3[200200][3], typ2[200200], typ3[200200], nxt[200200];
// G[s]: children of s in the rooted tree; track() erases entries it has placed.
vector<int> G[200200];

// Computes dp[s][*][*] and the choice arrays for the subtree of s (parent pa).
// Each candidate list is seeded with three zero-valued placeholders (ids -1..-3).
// This keeps the small top-k searches below in bounds and always allows
// "use no child for this role".
void dfs(int s, int pa = 0){
	vector<pair<ll, int>> C1, C2, C3;
	for (int i=1;i<=3;i++){
		C1.emplace_back(0, -i);
		C2.emplace_back(0, -i);
		C3.emplace_back(0, -i);
	}

	ll sum = 0;   // total a[] over the children of s
	int deg = 0;  // number of children of s

	for (auto &v:adj[s]) if (v!=pa){
		G[s].push_back(v);
		dfs(v, s);
		sum += a[v];
		deg++;

		// C1/C2 hold gains on top of a[v] (which is already counted in sum).
		C1.emplace_back(dp[v][1][0] - a[v], v);  // enter below v, leave at v
		C2.emplace_back(dp[v][0][2] - a[v], v);  // enter at v, finish inside v's subtree
		C3.emplace_back(dp[v][1][2], v);         // full value: enter below v, finish inside
	}

	sort(C1.begin(), C1.end(), greater<pair<ll, int>>());
	sort(C2.begin(), C2.end(), greater<pair<ll, int>>());
	sort(C3.begin(), C3.end(), greater<pair<ll, int>>());

	// (0,1): s, then one child's (1,0) route, then every remaining child.
	dp[s][0][1] = a[s] + sum;
	dp[s][0][1] += C1[0].first;
	v1[s] = C1[0].second;

	if (deg==0) dp[s][0][1] = -INF;

	// (0,2), variant 2: s, a C1 child's (1,0) route, the remaining children, and
	// finally a different C2 child followed by its (0,2) route.
	dp[s][0][2] = -INF;
	for (int i=0;i<2;i++){
		for (int j=0;j<2;j++) if (C1[i].second != C2[j].second){
			ll val = a[s] + sum + C1[i].first + C2[j].first;

			if (dp[s][0][2] < val){
				dp[s][0][2] = val;
				typ2[s] = 2;

				v2[s][0] = C1[i].second;
				v2[s][1] = C2[j].second;
			}
		}
	}

	// (0,2), variant 1: s, then straight into one child's (1,2) route; the other
	// children are skipped.
	if (dp[s][0][2] < a[s] + C3[0].first){
		dp[s][0][2] = a[s] + C3[0].first;
		typ2[s] = 1;

		v2[s][0] = C3[0].second;
	}

	// (1,2): the route begins at a child placed before s. C4 lists the options for
	// that child v. Either v with its (0,1) route, whose gain is
	// dp[v][1][0] - a[v] because dp[v][1][0] == dp[v][0][1] (see below). Or v alone,
	// with gain 0, encoded as id n + v.
	// NOTE(review): only the top 2-3 entries of each list are tried; presumably enough
	// to find the best pairwise-distinct choice — confirm, since C4 can hold two
	// entries for the same child.
	dp[s][1][2] = -INF;

	vector<pair<ll, int>> C4 = C1;
	for (auto &v:adj[s]) if (v!=pa) C4.emplace_back(0, n+v);
	sort(C4.begin(), C4.end(), greater<pair<ll, int>>());

	// Variant 3: C4 child, s, a C1 child's (1,0) route, the remaining children,
	// then a C2 child followed by its (0,2) route. All three children are distinct.
	for (int i=0;i<3;i++){
		for (int j=0;j<3;j++){
			for (int k=0;k<3;k++){
				int rc4 = C4[i].second;
				if (rc4 > n) rc4 -= n;
				if (rc4==C1[j].second || rc4==C2[k].second || C1[j].second==C2[k].second) continue;
				if (rc4 < 0) continue;

				ll val = a[s] + sum + C4[i].first + C1[j].first + C2[k].first;

				if (dp[s][1][2] < val){
					dp[s][1][2] = val;
					typ3[s] = 3;

					v3[s][0] = C4[i].second;
					v3[s][1] = C1[j].second;
					v3[s][2] = C2[k].second;
				}
			}
		}
	}

	// Variant 2: C4 child, s, then straight into a different child's (1,2) route.
	// The other children are skipped, so sum is not used here.
	for (int i=0;i<2;i++){
		for (int j=0;j<2;j++) if (C4[i].second != C3[j].second && C4[i].second > 0){
			int rc4 = C4[i].second;
			if (rc4 > n) rc4 -= n;
			if (rc4==C3[j].second) continue;

			ll val = a[s] + C4[i].first + C3[j].first;
			// C4[i].first is a gain on top of the child's own value, so add a[] of the
			// decoded child rc4. The encoded id C4[i].second (n + v for "child alone")
			// would read an unrelated slot, or past the end of a[].
			val += a[rc4];

			if (dp[s][1][2] < val){
				dp[s][1][2] = val;
				typ3[s] = 2;

				v3[s][0] = C4[i].second;
				v3[s][1] = C3[j].second;
			}
		}
	}

	if (deg==0) dp[s][1][2] = -INF;

	// (1,0) is the reverse of (0,1), so it has the same value.
	dp[s][1][0] = dp[s][0][1];

	// Variant 0: a free end may also simply be s itself.
	if (dp[s][1][2] < dp[s][1][0]) dp[s][1][2] = dp[s][1][0], typ3[s] = 0;
}

// Rebuilds the route for state (x, y) of vertex s by linking nxt[]. Returns
// {first vertex, last vertex} of that route. Replays the choices dfs() stored in
// v1/v2/v3/typ2/typ3 and removes children from G[s] once they are placed.
// track(s, 0, 0) is the single-vertex route {s, s}.
pair<int, int> track(int s, int x, int y){
	if (x==0 && y==0) return {s, s};
	assert(s > 0);
	assert(dp[s][x][y] >= 0);

	// dp[s][1][2] was taken from dp[s][1][0]: replay that state instead.
	if (x==1 && y==2 && typ3[s]==0) y = 0;

	// (0,1): s, then v1[s]'s (1,0) route (starts below v1[s], ends at v1[s]), then
	// every remaining child in G[s] order.
	if (x==0 && y==1){
		pair<int, int> ret = {s, s};

		if (v1[s] > 0){
			auto [l1, r1] = track(v1[s], 1, 0);
			nxt[s] = l1;
			ret.second = r1;

			G[s].erase(find(G[s].begin(), G[s].end(), v1[s]));
		}

		for (auto &v:G[s]){
			nxt[ret.second] = v;
			ret.second = v;
		}

		return ret;
	}

	if (x==0 && y==2){
		pair<int, int> ret = {s, s};

		// Variant 1: s, then directly into child v2[s][0]'s (1,2) route.
		if (typ2[s]==1){
			if (v2[s][0] > 0){
				auto [l1, r1] = track(v2[s][0], 1, 2);
				nxt[s] = l1;

				ret.second = r1;
			}

			return ret;
		}

		// Variant 2 (typ2[s]==2): s, v2[s][0]'s (1,0) route, the other children,
		// and finally v2[s][1] followed by its own (0,2) route.
		if (v2[s][0] > 0){
			auto [l1, r1] = track(v2[s][0], 1, 0);
			nxt[s] = l1;

			ret.second = r1;
			G[s].erase(find(G[s].begin(), G[s].end(), v2[s][0]));
		}

		if (v2[s][1] > 0){
			G[s].erase(find(G[s].begin(), G[s].end(), v2[s][1]));
		}

		for (auto &v:G[s]){
			nxt[ret.second] = v;
			ret.second = v;
		}

		if (v2[s][1] > 0){
			nxt[ret.second] = v2[s][1];
			ret.second = v2[s][1];

			auto [l2, r2] = track(v2[s][1], 0, 2);
			ret.second = r2;
		}

		return ret;
	}

	// (1,0): mirror of (0,1). Every remaining child is prepended, so they come
	// first in reverse G[s] order. Then comes v1[s]'s (0,1) route, which links
	// back to s.
	if (x==1 && y==0){
		pair<int, int> ret = {s, s};

		if (v1[s] > 0){
			auto [l1, r1] = track(v1[s], 0, 1);
			nxt[r1] = s;
			ret.first = l1;

			G[s].erase(find(G[s].begin(), G[s].end(), v1[s]));
		}

		for (auto &v:G[s]){
			nxt[v] = ret.first;
			ret.first = v;
		}

		assert(ret.first != s);
		return ret;
	}

	assert(x==1 && y==2);

	pair<int, int> ret = {s, s};

	// Variant 2: the C4 child (its (0,1) route, or just the child itself when the
	// stored id is > n), then s, then child v3[s][1]'s (1,2) route.
	if (typ3[s]==2){
		if (v3[s][0] > 0){
			int rv3 = v3[s][0], ri = 0, rj = 1;
			if (v3[s][0] > n) rv3 = v3[s][0]-n, ri = 0, rj = 0;

			auto [l1, r1] = track(rv3, ri, rj);
			ret.first = l1;
			nxt[r1] = s;

			G[s].erase(find(G[s].begin(), G[s].end(), rv3));
		}

		if (v3[s][1] > 0){
			auto [l2, r2] = track(v3[s][1], 1, 2);
			nxt[s] = l2;
			ret.second = r2;
		}

		assert(ret.first != s);
		return ret;
	}

	// Variant 3 (typ3[s]==3): the C4 child part, s, v3[s][1]'s (1,0) route, the
	// remaining children, then v3[s][2] followed by its (0,2) route.
	if (v3[s][0] > 0){
		int rv3 = v3[s][0], ri = 0, rj = 1;
		if (v3[s][0] > n) rv3 = v3[s][0]-n, ri = 0, rj = 0;

		auto [l1, r1] = track(rv3, ri, rj);
		ret.first = l1;
		nxt[r1] = s;

		G[s].erase(find(G[s].begin(), G[s].end(), rv3));
	}

	if (v3[s][1] > 0){
		auto [l2, r2] = track(v3[s][1], 1, 0);
		nxt[s] = l2;
		ret.second = r2;

		G[s].erase(find(G[s].begin(), G[s].end(), v3[s][1]));
	}

	if (v3[s][2] > 0){
		G[s].erase(find(G[s].begin(), G[s].end(), v3[s][2]));
	}

	for (auto &v:G[s]){
		nxt[ret.second] = v;
		ret.second = v;
	}

	if (v3[s][2] > 0){
		auto [l3, r3] = track(v3[s][2], 0, 2);
		nxt[ret.second] = l3;
		ret.second = r3;
	}

	assert(ret.first != s);
	return ret;
}

// Entry point for k == 2. Runs the DP from root 1, prints the optimum, rebuilds
// the route through nxt[], and prints its length and vertices.
void solve2(){
	for (int i=1;i<=n;i++) nxt[i] = 0, G[i].clear();
	dfs(1);

	ll ans = dp[1][0][2];
	printf("%lld\n", ans);

	track(1, 0, 2);

	// Walk the linked route from vertex 1; it must collect exactly ans.
	ll visited_sum = 0;
	vector<int> V;
	for (int i=1;i;i=nxt[i]){
		V.push_back(i);
		visited_sum += a[i];
	}

	assert(ans == visited_sum);

	printf("%d\n", (int)V.size());
	for (auto &x:V) printf("%d ", x);
	printf("\n");
}

} // namespace bar

namespace baz{

void dfs1(int s, int pa);
void dfs2(int s, int pa);

vector<int> V;

void dfs1(int s, int pa){
	V.push_back(s);
	for (auto &v:adj[s]) if (v!=pa) dfs2(v, s);
}

void dfs2(int s, int pa){
	for (auto &v:adj[s]) if (v!=pa) dfs1(v, s);
	V.push_back(s);
}

void solve3(){
	dfs1(1, 0);

	ll sum = 0;
	for (int i=1;i<=n;i++) sum += a[i];

	printf("%lld\n", sum);
	printf("%d\n", n);
	for (auto &x:V) printf("%d ", x);
	printf("\n");
}

} // namespace baz


// Disjoint-set union over vertices 1..n (used by the random tree generator).
struct DSU{
	// path[v]: parent link of v; v is a root iff path[v] == v. Sized like the
	// other per-vertex arrays (200200, previously 100100) so that init(n) stays
	// in bounds for any n the rest of the program supports.
	int path[200200];

	// Makes each of 1..n its own singleton set.
	void init(int n){for (int i=1;i<=n;i++) path[i] = i;}

	// Returns the root of s's set and points every vertex on the way directly at
	// it. Iterative, so long parent chains cannot exhaust the call stack.
	int find(int s){
		int root = s;
		while (path[root] != root) root = path[root];
		while (path[s] != root){
			const int up = path[s];
			path[s] = root;
			s = up;
		}
		return root;
	}

	// Unites the sets of s and v by hanging v's root under s's root.
	// Returns 1 if two different sets were merged, 0 if they were already one set.
	int merge(int s, int v){
		s = find(s), v = find(v);
		if (s==v) return 0;
		path[v] = s;
		return 1;
	}
}dsu;

mt19937 seed(1557);
uniform_int_distribution<int> rng(0, 2147483647);
// Pseudo-random integer in [l, r] for the stress generator. The seed is fixed, so
// runs are reproducible; the modulo reduction is slightly biased.
int getrand(int l, int r){
	const int width = r - l + 1;
	const int draw = rng(seed);
	return l + draw % width;
}

// Builds a random stress-test instance: n in [1, 5] vertices, every a[v] = 1,
// and a random spanning tree. Vertex pairs are drawn until one joins two
// different DSU components; that pair becomes the next edge.
void gen(){
	n = getrand(1, 5);

	dsu.init(n);
	for (int v = 1; v <= n; v++) adj[v].clear();
	for (int v = 1; v <= n; v++) a[v] = getrand(1, 1);

	int edges = 0;
	while (edges < n - 1){
		int x = getrand(1, n);
		int y = getrand(1, n);
		if (dsu.find(x) == dsu.find(y)) continue;  // would close a cycle: redraw

		dsu.merge(x, y);
		adj[x].push_back(y);
		adj[y].push_back(x);
		edges++;
	}
}

// Stress-test driver for the k == 2 solver. Generates a small random tree, echoes
// it in the input format, then runs bar::solve2, whose internal assert checks the
// printed total against the reconstructed route.
void stress(int tc){
	printf("------------------------\n");
	printf("Stress #%d\n", tc);

	gen();

	printf("Input:\n");
	printf("%d 2\n", n);

	// Each edge is stored twice; print it once, smaller endpoint first.
	for (int hi = 1; hi <= n; hi++){
		for (int lo : adj[hi]){
			if (lo < hi) printf("%d %d\n", lo, hi);
		}
	}

	for (int v = 1; v <= n; v++) printf("%lld ", a[v]);
	printf("\n");

	bar::solve2();

	if (tc % 10000 == 0) printf("ok %d done\n", tc);
}

// Reads "n k", the n-1 tree edges and the n vertex values. Then dispatches to the
// solver for k == 1, k == 2, or (any other k) the k >= 3 construction.
// Exits with status 1, printing nothing, if the input ends early or is malformed.
// Otherwise k would be uninitialized or the tree only partially read.
int main(){
	// Uncomment to run the randomized self-check loop instead of reading input:
	// for (int i=1;;i++) stress(i);

	int k;
	if (scanf("%d %d", &n, &k) != 2) return 1;

	for (int i=1;i<=n-1;i++){
		int x, y;
		if (scanf("%d %d", &x, &y) != 2) return 1;
		adj[x].push_back(y);
		adj[y].push_back(x);
	}

	for (int i=1;i<=n;i++){
		if (scanf("%lld", a+i) != 1) return 1;
	}

	if (k==1) foo::solve1();
	else if (k==2) bar::solve2();
	else baz::solve3();
	return 0;
}

컴파일 시 표준 에러 (stderr) 메시지

Main.cpp: In function 'int main()':
Main.cpp:473:7: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  473 |  scanf("%d %d", &n, &k);
      |  ~~~~~^~~~~~~~~~~~~~~~~
Main.cpp:477:8: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  477 |   scanf("%d %d", &x, &y);
      |   ~~~~~^~~~~~~~~~~~~~~~~
Main.cpp:482:30: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  482 |  for (int i=1;i<=n;i++) scanf("%lld", a+i);
      |                         ~~~~~^~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...