제출 #794332

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
794332 | | qwerasdfzxcl | Travelling Trader (CCO23_day2problem2) | C++17
25 / 25
441 ms | 180296 KiB
#include <bits/stdc++.h>

using namespace std;
typedef long long ll;

// Number of vertices; read from stdin in main() or chosen at random by gen().
int n;
// Adjacency lists indexed by 1-based vertex id; every edge is stored in both directions.
vector<int> adj[200200];
// a[i] is the value attached to vertex i (1-based); the solvers print sums of these values.
ll a[200200];

namespace foo{

ll dp[200200];
int nxt[200200];

void dfs(int s, int pa = 0){
	dp[s] = a[s], nxt[s] = 0;
	for (auto &v:adj[s]) if (v!=pa){
		dfs(v, s);
		if (dp[s] < dp[v] + a[s]){
			dp[s] = dp[v] + a[s];
			nxt[s] = v;
		}
	}
}

void solve1(){
	dfs(1);
	printf("%lld\n", dp[1]);

	vector<int> V;
	for (int i=1;i;i=nxt[i]){
		V.push_back(i);
	}

	printf("%d\n", (int)V.size());
	for (auto &x:V) printf("%d ", x);
	printf("\n");
}

} // namespace foo


namespace bar{

// Magnitude of the sentinel; dfs() stores -INF in states it marks as unavailable.
constexpr ll INF = 4e18;

// dp[s][x][y]: best value found by dfs() for vertex s in state (x, y).
// dfs() only writes the states (0,1), (0,2), (1,0) and (1,2).
ll dp[200200][2][3];
// Choices recorded by dfs() and replayed by track():
//   v1[s]        - child picked for states (0,1) / (1,0) (negative = none)
//   v2[s][*]     - children picked for state (0,2); typ2[s] tells which formula won
//   v3[s][*]     - children picked for state (1,2); typ3[s] tells which formula won
//   nxt[v]       - successor of v in the final vertex order (0 = end of the order)
int v1[200200], v2[200200][2], v3[200200][3], typ2[200200], typ3[200200], nxt[200200];
// G[s]: children of s, filled by dfs(); track() erases entries as it places them.
vector<int> G[200200];

// Bottom-up pass over the subtree of s (pa is the parent of s, 0 at the root).
// Fills dp[s][0][1], dp[s][0][2], dp[s][1][0] and dp[s][1][2], records in
// v1/v2/v3/typ2/typ3 which children produced each maximum so that track() can
// rebuild the vertex order, and stores the children of s in G[s].
// NOTE(review): judging by how track() links the chains, x == 0 seems to mean
// "the chain begins at s" and x == 1 "the chain may begin below s", while y seems
// to bound how far from s the chain ends - confirm against the problem analysis.
void dfs(int s, int pa = 0){
	// Candidate lists of (gain, child id), sorted in descending order below.
	// Each list is seeded with three zero-gain sentinels carrying the distinct
	// negative ids -1, -2, -3, so indexing the first few entries is always valid
	// and "no child" options never compare equal to each other or to a real child.
	vector<pair<ll, int>> C1, C2, C3, C6;
	for (int i=1;i<=3;i++){
		C1.emplace_back(0, -i);
		C2.emplace_back(0, -i);
		C3.emplace_back(0, -i);
		C6.emplace_back(0, -i);
	}

	// sum: total of a[] over the children of s; deg: number of children.
	ll sum = 0;
	int deg = 0;

	for (auto &v:adj[s]) if (v!=pa){
		G[s].push_back(v);
		dfs(v, s);
		sum += a[v];
		deg++;

		// C1/C2/C6 hold the extra gain over a[v] alone (a[v] is already part of sum);
		// C3 holds the full value of state (1,2) of the child.
		C1.emplace_back(dp[v][1][0] - a[v], v);
		C2.emplace_back(dp[v][0][2] - a[v], v);
		C3.emplace_back(dp[v][1][2], v);
		C6.emplace_back(dp[v][0][1] - a[v], v);
	}

	sort(C1.begin(), C1.end(), greater<pair<ll, int>>());
	sort(C2.begin(), C2.end(), greater<pair<ll, int>>());
	sort(C3.begin(), C3.end(), greater<pair<ll, int>>());
	sort(C6.begin(), C6.end(), greater<pair<ll, int>>());

	// calc 0 1
	// s, all children, plus the best single C1 gain (a sentinel means no child is expanded).
	dp[s][0][1] = a[s] + sum;
	dp[s][0][1] += C1[0].first;
	v1[s] = C1[0].second;

	// A leaf has no child, so this state is marked unavailable.
	if (deg==0) dp[s][0][1] = -INF;

	// calc 0 2
	// Option typ2 == 2: s, all children, one C1 gain and one C2 gain taken from
	// different entries. Only the two best entries of each list are tried.
	dp[s][0][2] = -INF;
	for (int i=0;i<2;i++){
		for (int j=0;j<2;j++) if (C1[i].second != C2[j].second){
			ll val = a[s] + sum + C1[i].first + C2[j].first;

			if (dp[s][0][2] < val){
				dp[s][0][2] = val;
				typ2[s] = 2;

				v2[s][0] = C1[i].second;
				v2[s][1] = C2[j].second;
			}
		}
	}

	// Option typ2 == 1: s followed only by the best state-(1,2) value of one child.
	if (dp[s][0][2] < a[s] + C3[0].first){
		dp[s][0][2] = a[s] + C3[0].first;
		typ2[s] = 1;

		v2[s][0] = C3[0].second;
	}

	// calc 1 2
	dp[s][1][2] = -INF;

	// C4 = C1 plus one zero-gain entry per child, stored with id n+v.
	// track() decodes an id above n as "child v on its own" (state (0,0)) and an
	// id in 1..n as "child v expanded as state (0,1)".
	vector<pair<ll, int>> C4 = C1;
	for (auto &v:adj[s]) if (v!=pa) C4.emplace_back(0, n+v);
	sort(C4.begin(), C4.end(), greater<pair<ll, int>>());

	// if (s==12){
	// 	for (int i=0;i<6;i++) printf("(%lld, %d) ", C4[i].first, C4[i].second);
	// 	printf("\n");
	// 	for (int i=0;i<3;i++) printf("(%lld, %d) ", C1[i].first, C1[i].second);
	// 	printf("\n");
	// 	for (int i=0;i<3;i++) printf("(%lld, %d) ", C2[i].first, C2[i].second);
	// 	printf("\n");
	// }

	// Option typ3 == 3: one C4 entry (must be a real child), one C1 entry and one
	// C2 entry, all referring to pairwise different children, on top of a[s] + sum.
	for (int i=0;i<min(6, (int)C4.size());i++){
		for (int j=0;j<3;j++){
			for (int k=0;k<3;k++){
				// rc4 is the child id behind the C4 entry, with the n+v encoding removed.
				int rc4 = C4[i].second;
				if (rc4 > n) rc4 -= n;
				if (rc4==C1[j].second || rc4==C2[k].second || C1[j].second==C2[k].second) continue;
				if (rc4 < 0) continue;

				ll val = a[s] + sum + C4[i].first + C1[j].first + C2[k].first;

				if (dp[s][1][2] < val){
					dp[s][1][2] = val;
					typ3[s] = 3;

					v3[s][0] = C4[i].second;
					v3[s][1] = C1[j].second;
					v3[s][2] = C2[k].second;
				}
			}
		}
	}

	// if (s==5) printf(" (%lld, %d) / (%lld, %d)\n", C3[0].first, C3[0].second, C3[1].first, C3[1].second);
	// if (s==5) printf(" (%lld, %d) / (%lld, %d)\n", C4[0].first, C4[0].second, C4[1].first, C4[1].second);

	// Option typ3 == 2: one real child from C4 (its a[] value is added explicitly,
	// because sum is not used here) plus the state-(1,2) value of a different child.
	for (int i=0;i<min(4, (int)C4.size());i++){
		for (int j=0;j<2;j++) if (C4[i].second != C3[j].second && C4[i].second > 0){
			int rc4 = C4[i].second;
			if (rc4 > n) rc4 -= n;
			if (rc4==C3[j].second) continue;
			if (rc4 < 0) continue;

			ll val = a[s] + C4[i].first + C3[j].first;
			if (C4[i].second > 0) val += a[rc4];

			// printf("debug %lld + %lld + %lld -> %lld\n", a[s], C4[i].first, C3[j].first, val);

			if (dp[s][1][2] < val){
				dp[s][1][2] = val;
				typ3[s] = 2;

				v3[s][0] = C4[i].second;
				v3[s][1] = C3[j].second;
			}
		}
	}

	// C5: for every real child, the gain of its state (1,2) over a[child] alone.
	vector<pair<ll, int>> C5;

	for (int i=0;i<(int)C3.size();i++) if (C3[i].second > 0){
		C5.emplace_back(C3[i].first-a[C3[i].second], C3[i].second);

		// ll val = a[s] + sum - a[C3[i].second] + C3[i].first;
		// if (dp[s][1][2] < val){
		// 	dp[s][1][2] = val;
		// 	typ3[s] = 1;

		// 	v3[s][0] = C3[i].second;

		// }
	}

	// Option typ3 == 1: s, all children, one C6 gain and one C5 gain taken from
	// different entries; v3[s][0] is the C6 pick and v3[s][1] the C5 pick.
	sort(C5.begin(), C5.end(), greater<pair<ll, int>>());
	for (int i=0;i<2;i++){
		for (int j=0;j<min(2, (int)C5.size());j++) if (C6[i].second != C5[j].second){
			ll val = a[s] + sum + C5[j].first + C6[i].first;
			if (dp[s][1][2] < val){
				dp[s][1][2] = val;
				typ3[s] = 1;

				v3[s][0] = C6[i].second;
				v3[s][1] = C5[j].second;

			}
		}
	}

	// A leaf has no child, so this state is marked unavailable.
	if (deg==0) dp[s][1][2] = -INF;

	// calc 1 0
	// Same value as state (0,1); track() builds it with the links reversed.
	dp[s][1][0] = dp[s][0][1];

	// Option typ3 == 0: fall back to state (1,0) when it beats every option above.
	if (dp[s][1][2] < dp[s][1][0]) dp[s][1][2] = dp[s][1][0], typ3[s] = 0;


	// printf("\nok %d done / sum = %lld / deg = %d / val = %lld\n", s, sum, deg, a[s]);
	// printf("0 1: %lld / v1 = %d\n", dp[s][0][1], v1[s]);
	// printf("1 0: %lld / v1 = %d\n", dp[s][1][0], v1[s]);
	// printf("0 2: %lld / typ2 = %d / v2 = %d %d\n", dp[s][0][2], typ2[s], v2[s][0], v2[s][1]);
	// printf("1 2: %lld / typ3 = %d / v3 = %d %d %d\n", dp[s][1][2], typ3[s], v3[s][0], v3[s][1], v3[s][2]);
}

// Rebuilds the vertex order behind dp[s][x][y] by writing successor links into nxt[].
// Returns {first vertex, last vertex} of the chain built for the subtree of s.
// State (0,0) is the single vertex s. For every other state the choices stored by
// dfs() in v1/v2/v3/typ2/typ3 are replayed. G[s] is modified: children that get a
// dedicated position are erased, and whatever remains is linked in stored order.
pair<int, int> track(int s, int x, int y){
	// printf(" call %d %d %d\n", s, x, y);
	if (x==0 && y==0) return {s, s};
	assert(s > 0);
	assert(dp[s][x][y] >= 0);

	// typ3 == 0 means dfs() reused the (1,0) value for state (1,2).
	if (x==1 && y==2 && typ3[s]==0) y = 0;

	if (x==0 && y==1){
		// Order: s, then the (1,0) chain of v1[s] (if any), then the other children.
		pair<int, int> ret = {s, s};

		if (v1[s] > 0){
			auto [l1, r1] = track(v1[s], 1, 0);
			nxt[s] = l1;
			ret.second = r1;

			G[s].erase(find(G[s].begin(), G[s].end(), v1[s]));
		}

		// Append the remaining children one after another.
		for (auto &v:G[s]){
			nxt[ret.second] = v;
			ret.second = v;
		}

		return ret;
	}

	if (x==0 && y==2){
		pair<int, int> ret = {s, s};

		if (typ2[s]==1){
			// Order: s, then the (1,2) chain of v2[s][0]; other children are not linked.
			if (v2[s][0] > 0){
				auto [l1, r1] = track(v2[s][0], 1, 2);
				nxt[s] = l1;

				ret.second = r1;
			}

			return ret;
			
		}

		// else (typ2[s]==2)
		// Order: s, the (1,0) chain of v2[s][0], the other children, and finally
		// v2[s][1] followed by its own (0,2) chain.
		if (v2[s][0] > 0){
			auto [l1, r1] = track(v2[s][0], 1, 0);
			nxt[s] = l1;

			ret.second = r1;
			G[s].erase(find(G[s].begin(), G[s].end(), v2[s][0]));
		}

		// Take v2[s][1] out of the plain children so that it is placed last.
		if (v2[s][1] > 0){
			G[s].erase(find(G[s].begin(), G[s].end(), v2[s][1]));
		}

		for (auto &v:G[s]){
			nxt[ret.second] = v;
			ret.second = v;
		}

		if (v2[s][1] > 0){
			nxt[ret.second] = v2[s][1];
			ret.second = v2[s][1];

			// l2 is not used: the link into v2[s][1] was already written just above.
			auto [l2, r2] = track(v2[s][1], 0, 2);
			ret.second = r2;
		}

		return ret;
	}

	if (x==1 && y==0){
		// Mirror image of (0,1): the chain ends at s and everything else precedes it.
		pair<int, int> ret = {s, s};

		if (v1[s] > 0){
			auto [l1, r1] = track(v1[s], 0, 1);
			nxt[r1] = s;
			ret.first = l1;

			G[s].erase(find(G[s].begin(), G[s].end(), v1[s]));
		}

		// Prepend the remaining children, each linking to the current front.
		for (auto &v:G[s]){
			nxt[v] = ret.first;
			ret.first = v;
		}

		assert(ret.first != s);
		return ret;
	}

	assert(x==1 && y==2);

	pair<int, int> ret = {s, s};

	if (typ3[s]==2){
		// Order: prefix built from v3[s][0], then s, then the (1,2) chain of v3[s][1].
		if (v3[s][0] > 0){
			// Ids above n encode "child alone" (state (0,0)); otherwise state (0,1).
			int rv3 = v3[s][0], ri = 0, rj = 1;
			if (v3[s][0] > n) rv3 = v3[s][0]-n, ri = 0, rj = 0;

			auto [l1, r1] = track(rv3, ri, rj);
			ret.first = l1;
			nxt[r1] = s;

			G[s].erase(find(G[s].begin(), G[s].end(), rv3));
		}

		if (v3[s][1] > 0){
			auto [l2, r2] = track(v3[s][1], 1, 2);
			nxt[s] = l2;
			ret.second = r2;
		}

		assert(ret.first != s);
		return ret;
	}

	if (typ3[s]==1){
		// Order: the other children and the (0,1) chain of v3[s][0] come before s,
		// then s, then the (1,2) chain of v3[s][1].
		assert(v3[s][1] > 0);
		G[s].erase(find(G[s].begin(), G[s].end(), v3[s][1]));

		if (v3[s][0] > 0){
			G[s].erase(find(G[s].begin(), G[s].end(), v3[s][0]));

			auto [l2, r2] = track(v3[s][0], 0, 1);
			nxt[r2] = ret.first;
			ret.first = l2;
		}

		// Prepend the remaining children in front of what has been built so far.
		for (auto &v:G[s]){
			nxt[v] = ret.first;
			ret.first = v;
		}



		auto [l1, r1] = track(v3[s][1], 1, 2);
		nxt[s] = l1;
		ret.second = r1;

		return ret;
	}

	//typ3[s] == 3
	// Order: prefix built from v3[s][0], then s, the (1,0) chain of v3[s][1],
	// the other children, and finally the (0,2) chain of v3[s][2].
	if (v3[s][0] > 0){
		// Ids above n encode "child alone" (state (0,0)); otherwise state (0,1).
		int rv3 = v3[s][0], ri = 0, rj = 1;
		if (v3[s][0] > n) rv3 = v3[s][0]-n, ri = 0, rj = 0;

		auto [l1, r1] = track(rv3, ri, rj);
		ret.first = l1;
		nxt[r1] = s;

		G[s].erase(find(G[s].begin(), G[s].end(), rv3));
	}

	if (v3[s][1] > 0){
		auto [l2, r2] = track(v3[s][1], 1, 0);
		nxt[s] = l2;
		ret.second = r2;

		G[s].erase(find(G[s].begin(), G[s].end(), v3[s][1]));
	}

	// Take v3[s][2] out of the plain children so that it is placed last.
	if (v3[s][2] > 0){
		G[s].erase(find(G[s].begin(), G[s].end(), v3[s][2]));
	}

	for (auto &v:G[s]){
		nxt[ret.second] = v;
		ret.second = v;
	}

	if (v3[s][2] > 0){
		auto [l3, r3] = track(v3[s][2], 0, 2);
		nxt[ret.second] = l3;
		ret.second = r3;
	}

	assert(ret.first != s);
	return ret;
}

// Solves the k == 2 case: prints the best value, then the number of vertices in
// the reconstructed order and the vertices themselves. Returns the best value.
ll solve2(){
	// Reset per-run state so the function can be called repeatedly (stress mode).
	for (int v = 1; v <= n; v++){
		G[v].clear();
		nxt[v] = 0;
	}

	dfs(1);

	const ll ans = dp[1][0][2];
	printf("%lld\n", ans);

	track(1, 0, 2);

	// State (0,2) starts at the root, so walk the successor links from vertex 1
	// and re-add the values as a consistency check on the reconstruction.
	vector<int> order;
	ll collected = 0;
	int cur = 1;
	while (cur != 0){
		order.push_back(cur);
		collected += a[cur];
		cur = nxt[cur];
	}

	assert(ans == collected);

	printf("%d\n", (int)order.size());
	for (size_t idx = 0; idx < order.size(); idx++){
		printf("%d ", order[idx]);
	}
	printf("\n");

	return ans;
}

} // namespace bar

namespace baz{

void dfs1(int s, int pa);
void dfs2(int s, int pa);

vector<int> V;

void dfs1(int s, int pa){
	V.push_back(s);
	for (auto &v:adj[s]) if (v!=pa) dfs2(v, s);
}

void dfs2(int s, int pa){
	for (auto &v:adj[s]) if (v!=pa) dfs1(v, s);
	V.push_back(s);
}

void solve3(){
	dfs1(1, 0);

	ll sum = 0;
	for (int i=1;i<=n;i++) sum += a[i];

	printf("%lld\n", sum);
	printf("%d\n", n);
	for (auto &x:V) printf("%d ", x);
	printf("\n");
}

} // namespace baz

namespace naive
{
	// Magnitude used to mark states that have not been reached.
	constexpr int inf = 1e9 + 100;
	// dp[mask][v][t]: best total for collected set `mask` while standing on vertex v;
	// layer 0 is written when v has just been added to the mask, layer 1 when v was
	// reached by one edge from a layer-0 state without changing the mask.
	// arr[] and gph[] are 0-based copies of a[] and adj[].
	int dp[1<<20][20][2], arr[20];
	vector<int> gph[20];

	// Exhaustive bitmask search used as the reference answer in stress().
	// The table sizes limit it to n <= 20.
	int solve()
	{
		const int N = n;
		const int full = 1 << N;

		// Copy the instance into 0-based local storage.
		for (int v = 0; v < N; v++) {
			gph[v].clear();
			for (int to : adj[v + 1]) gph[v].push_back(to - 1);
			arr[v] = a[v + 1];
		}

		for (int mask = 0; mask < full; mask++) {
			for (int v = 0; v < N; v++) {
				dp[mask][v][0] = -inf;
				dp[mask][v][1] = -inf;
			}
		}
		dp[1][0][0] = arr[0];

		for (int mask = 1; mask < full; mask++) {
			// Step along one edge without collecting anything (layer 0 -> layer 1).
			for (int v = 0; v < N; v++) {
				for (int to : gph[v]) {
					if (dp[mask][to][1] < dp[mask][v][0]) dp[mask][to][1] = dp[mask][v][0];
				}
			}
			// Step along one edge onto a vertex outside the mask and collect it.
			for (int v = 0; v < N; v++) {
				const int best = max(dp[mask][v][0], dp[mask][v][1]);
				for (int to : gph[v]) {
					if (mask & (1 << to)) continue;
					int &slot = dp[mask | (1 << to)][to][0];
					slot = max(slot, best + arr[to]);
				}
			}
		}

		int ans = 0;
		for (int mask = 0; mask < full; mask++) {
			for (int v = 0; v < N; v++) {
				ans = max({ans, dp[mask][v][0], dp[mask][v][1]});
			}
		}
		return ans;
	}
} // namespace naive

// Disjoint-set union over ids 1..n with path compression (no union by size).
struct DSU{
	// path[v]: parent of v; a root is its own parent.
	int path[100100];

	// Makes every id in 1..n a singleton set.
	void init(int n){
		for (int v = n; v >= 1; v--) path[v] = v;
	}

	// Returns the representative of s and points every vertex on the way at it.
	int find(int s){
		int root = s;
		while (path[root] != root) root = path[root];

		// Second pass: compress the walked path directly onto the root.
		while (path[s] != root){
			const int up = path[s];
			path[s] = root;
			s = up;
		}
		return root;
	}

	// Joins the sets of s and v. Returns 1 if they were separate, 0 if already joined.
	int merge(int s, int v){
		const int rs = find(s);
		const int rv = find(v);
		if (rs == rv) return 0;
		path[rv] = rs;
		return 1;
	}
}dsu;

// Fixed-seed engine, so every stress run replays the same sequence of instances.
mt19937 seed(1557);
// Draws from the whole non-negative int range.
uniform_int_distribution<int> rng(0, 2147483647);

// Returns a pseudo-random integer in [l, r]; callers must pass l <= r.
// Each call consumes exactly one draw from the engine.
int getrand(int l, int r){
	const int span = r - l + 1;
	const int draw = rng(seed);
	return draw % span + l;
}

// Builds a random tree instance for stress testing: 1..12 vertices, every a[i]
// drawn from [1, 1], and n-1 edges that each connect two different components.
void gen(){
	n = getrand(1, 12);
	dsu.init(n);

	for (int v = 1; v <= n; v++){
		adj[v].clear();
		a[v] = getrand(1, 1);
	}

	for (int added = 0; added < n - 1; added++){
		// Redraw until the endpoints lie in different components; merge() reports
		// that case with a non-zero result and joins them in the same call.
		int x, y;
		do{
			x = getrand(1, n);
			y = getrand(1, n);
		}while (!dsu.merge(x, y));

		adj[x].push_back(y);
		adj[y].push_back(x);
	}
}

void stress(int tc){
	// printf("------------------------\n");
	// printf("Stress #%d\n", tc);

	// int ans = gen();
	gen();
	// printf("Input:\n");
	// printf("%d 2\n", n);

	// for (int i=1;i<=n;i++){
	// 	for (auto &j:adj[i]) if (j<i) printf("%d %d\n", j, i);
	// }

	// for (int i=1;i<=n;i++) printf("%lld ", a[i]);
	// printf("\n");

	// int out = solve();

	// ans = naive(ans);
	// assert(ans < INF);

	ll ans = naive::solve();
	ll out = bar::solve2();

	// printf("Answer: %lld\n", ans);
	// printf("Output: %lld\n", out);

	if (ans!=out){
		printf("------------------------\n");
		printf("Stress #%d\n", tc);

		printf("Input:\n");
		printf("%d 2\n", n);

		for (int i=1;i<=n;i++){
			for (auto &j:adj[i]) if (j<i) printf("%d %d\n", j, i);
		}

		for (int i=1;i<=n;i++) printf("%lld ", a[i]);
		printf("\n");

		printf("Answer: %lld\n", ans);
		printf("Output: %lld\n", out);
	}

	assert(ans==out);

	// if (!(out==ans && out!=-1)){
	// 	printf("------------------------\n");
	// 	printf("Stress #%d\n", tc);

	// 	printf("Input:\n");
	// 	printf("%s\n%s\n", s+1, t+1);

	// 	printf("Answer: %d\n", ans);
	// 	printf("Output: %d\n", out);
	// }

	// assert(out == ans && out != -1);

	if (tc%10000==0) printf("ok %d done\n", tc);
}

// Reads "n k", then n-1 edges and n vertex values from stdin, and dispatches to
// the solver for the given k (1 -> foo, 2 -> bar, anything else -> baz).
// Returns 1 with a message on stderr when the input is truncated, malformed or
// out of range, instead of indexing the global arrays with unchecked values.
int main(){
	// Uncomment to run the randomized stress test instead of reading input.
	// for (int i=1;;i++) stress(i);

	// Vertices are 1-based, so the largest usable id is one below the array length.
	constexpr int kMaxN = (int)(sizeof(adj) / sizeof(adj[0])) - 1;

	int k;
	if (scanf("%d %d", &n, &k) != 2){
		fprintf(stderr, "failed to read n and k\n");
		return 1;
	}
	if (n < 1 || n > kMaxN){
		fprintf(stderr, "n out of range: %d\n", n);
		return 1;
	}

	for (int i=1;i<=n-1;i++){
		int x, y;
		if (scanf("%d %d", &x, &y) != 2){
			fprintf(stderr, "failed to read edge %d\n", i);
			return 1;
		}
		// Endpoints index adj[] directly, so reject anything outside 1..n.
		if (x < 1 || x > n || y < 1 || y > n){
			fprintf(stderr, "edge %d has an endpoint outside 1..%d\n", i, n);
			return 1;
		}
		adj[x].push_back(y);
		adj[y].push_back(x);
	}

	for (int i=1;i<=n;i++){
		if (scanf("%lld", a+i) != 1){
			fprintf(stderr, "failed to read value %d\n", i);
			return 1;
		}
	}

	if (k==1) foo::solve1();
	else if (k==2) bar::solve2();
	else baz::solve3();

	return 0;
}

컴파일 시 표준 에러 (stderr) 메시지

Main.cpp: In function 'int main()':
Main.cpp:615:7: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  615 |  scanf("%d %d", &n, &k);
      |  ~~~~~^~~~~~~~~~~~~~~~~
Main.cpp:619:8: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  619 |   scanf("%d %d", &x, &y);
      |   ~~~~~^~~~~~~~~~~~~~~~~
Main.cpp:624:30: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  624 |  for (int i=1;i<=n;i++) scanf("%lld", a+i);
      |                         ~~~~~^~~~~~~~~~~~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...