Submission #794173

# Submission time Handle Problem Language Result Execution time Memory
794173 2023-07-26T10:17:04 Z flappybird Travelling Trader (CCO23_day2problem2) C++17
0 / 25
3 ms 5076 KB
#include <bits/stdc++.h>
#include <cassert>
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("avx,avx2,fma")
using namespace std;
typedef long long ll;
typedef pair<ll, ll> pll;
typedef pair<int, int> pii;
#define MAX 202300
#define MAXS 20
#define INF 100000000000000001
#define bb ' '
#define ln '\n'
#define Ln '\n'
int C[MAX];
vector<int> adj[MAX];
int N, K;
namespace k1 {
	ll sum[MAX];
	int mv[MAX];
	void dfs(int x, int p = 0) {
		sum[x] = C[x];
		for (auto v : adj[x]) if (v != p) {
			dfs(v, x);
			if (sum[v] > sum[mv[x]]) mv[x] = v;
		}
		sum[x] += sum[mv[x]];
	}
	void solve() {
		dfs(1);
		vector<int> ansv;
		int v = 1;
		ll ans = 0;
		while (1) {
			ansv.push_back(v);
			ans += C[v];
			v = mv[v];
			if (!v) break;
		}
		cout << ans << ln;
		cout << ansv.size() << ln;
		for (auto v : ansv) cout << v << bb;
	}
}
namespace k2 {
	// (value, vertex) pair; dfs() uses it to rank the children of a vertex by a DP value.
	typedef pair<ll, int> pli;
	// dp[x]: filled by dfs() as C[x] + sum of C over x's children
	// + max over children of (dp[child] - C[child]); just C[x] at a non-root leaf.
	ll dp[MAX];
	// end[0][x] / end[1][x]: two further per-vertex DP values filled by dfs().
	// end[0][1] is the total that solve() prints. end[1][x] is forced to -INF
	// for non-root leaves and for vertices with a single child.
	ll end[2][MAX];
	// dpath[x]: the child that attains the maximum in dp[x]; calc() recurses into it.
	int dpath[MAX]; // dp path
	// epath[x]: the (first, second) children chosen for end[0][x].
	// first == -1 tells down() to continue into `second` in mode 1;
	// first == 0 means there is no child to hand to calc().
	pii epath[MAX]; // end path
	// chk[x]: set to 1 by dfs() when the winning end[1][x] candidate is the
	// two-child (md, me1) combination; selects the branch taken in down(x, 1).
	int chk[MAX];
	// e1path[x][0..2]: the children chosen for end[1][x]; slot 2 is only
	// written by the three-child candidate in dfs().
	int e1path[MAX][3];
	// How down(x, 1) consumes e1path (presumably keyed by chk[x] — confirm):
	// 0 : child -> calc(c) -> another calc(c) -> down(c, 0)
	// 1 : child -> calc(c) -> down(c, 1)
	// When non-zero, dfs() also fills sp[] and dep[], which lca()/dis() read.
	const int DEBUG = 1;
	// sp[x][i]: ancestor of x reached by 2^i parent steps (0 once past the root);
	// sp[x][0] is the parent, sp[x][i] = sp[sp[x][i-1]][i-1].
	int sp[MAX][MAXS];
	// dep[x]: depth of x, with the root (vertex 1) at depth 1.
	int dep[MAX] = { 0, 1 };
	// Post-order tree DP over the subtree of x (p is x's parent, 0 for the root).
	// Fills dp[x], end[0][x], end[1][x] and the reconstruction tables dpath[x],
	// epath[x], e1path[x], chk[x]; with DEBUG set it also fills sp[x][*] and the
	// depth of every child. All DP arrays are accumulated with +=, so they rely
	// on being zero-initialised globals and on dfs() running once per vertex.
	void dfs(int x, int p = 0) {
		if (DEBUG) {
			// Binary-lifting ancestor table; ancestors are already filled because
			// the parent is processed before its children.
			sp[x][0] = p;
			int i;
			for (i = 1; i < MAXS; i++) sp[x][i] = sp[sp[x][i - 1]][i - 1];
		}
		// Top-3 candidates over the children, best first, each (value, child):
		pli me[3]; // by end[0][child] - C[child]
		pli me1[3]; // by end[1][child] - C[child]
		pli md[3]; // by dp[child] - C[child]
		int i, j, k;
		// (-INF, -1) marks an empty slot; the code tests for it with !~second.
		for (i = 0; i < 3; i++) me1[i] = me[i] = md[i] = pli(-INF, -1);
		dp[x] += C[x];
		end[0][x] += C[x];
		end[1][x] += C[x];
		// Non-root leaf: dp and end[0] are just C[x]; end[1] is not available.
		if (p && adj[x].size() == 1) {
			end[1][x] = -INF;
			return;
		}
		int pv = 0; // last child visited (the only child in the single-child case)
		for (auto v : adj[x]) if (v != p) {
			pv = v;
			if (DEBUG) dep[v] = dep[x] + 1;
			dfs(v, x);
			// Every child's own C is added to all three values up front; the
			// candidates below are therefore stored net of C[v].
			dp[x] += C[v];
			end[0][x] += C[v];
			end[1][x] += C[v];
			pli d = pli(dp[v] - C[v], v);
			pli e = pli(end[0][v] - C[v], v);
			pli e1 = pli(end[1][v] - C[v], v);
			// Insert into the last slot if better, then bubble towards the front
			// so each table stays sorted in descending order.
			me[2] = max(me[2], e);
			md[2] = max(md[2], d);
			me1[2] = max(me1[2], e1);
			for (i = 2; i >= 1; i--) if (me[i] > me[i - 1]) swap(me[i], me[i - 1]);
			for (i = 2; i >= 1; i--) if (me1[i] > me1[i - 1]) swap(me1[i], me1[i - 1]);
			for (i = 2; i >= 1; i--) if (md[i] > md[i - 1]) swap(md[i], md[i - 1]);
		}
		dpath[x] = md[0].second;
		dp[x] += md[0].first;
		epath[x].second = pv;
		// c: x has exactly one child (non-root with two neighbours, or a root
		// with one neighbour).
		int c = 0;
		if (p && adj[x].size() <= 2) c = 1;
		if (!p && adj[x].size() == 1) c = 1;
		if (c) {
			end[0][x] += me[0].first;
			// Alternative: continue into the child in mode 1; first == -1 records that.
			if (end[0][x] < end[1][pv] + C[x]) {
				epath[x] = pii(-1, pv);
				end[0][x] = end[1][pv] + C[x];
			}
			end[1][x] = -INF; //-------------------------------------------
			return;
		}
		// end[0][x]: best pair of distinct children (one ranked by md, one by me),
		// or a single child ranked by me1.
		ll mx = -INF;
		if (md[0].second != me[0].second) {
			mx = md[0].first + me[0].first;
			epath[x] = pii(md[0].second, me[0].second);
		}
		else {
			// Both tables are topped by the same child: try each runner-up.
			if (mx < md[0].first + me[1].first) {
				mx = md[0].first + me[1].first;
				epath[x] = pii(md[0].second, me[1].second);
				assert(md[0].second != me[1].second);
			}
			if (mx < md[1].first + me[0].first) {
				mx = md[1].first + me[0].first;
				epath[x] = pii(md[1].second, me[0].second);
				assert(md[1].second != me[0].second);
			}
		}
		if (mx < me1[0].first) {
			mx = me1[0].first;
			epath[x] = pii(-1, me1[0].second);
		}
		end[0][x] += mx;
		// end[1][x], candidate A: two md children plus one me child, all distinct.
		mx = -INF;
		for (i = 0; i < 3; i++) for (j = i + 1; j < 3; j++) {
			// NOTE(review): slots i and j index md[] below but are validated on
			// me[]; this looks equivalent because both tables receive one entry
			// per child — confirm.
			if (!~me[i].second) continue;
			if (!~me[j].second) continue;
			for (k = 0; k < 3; k++) {
				if (!~me[k].second) continue;
				if (me[k].second == md[i].second) continue;
				if (me[k].second == md[j].second) continue;
				ll sum = me[k].first + md[i].first + md[j].first;
				if (mx < sum) {
					mx = sum;
					e1path[x][0] = md[i].second;
					e1path[x][1] = md[j].second;
					e1path[x][2] = me[k].second;
				}
			}
		}
		// end[1][x], candidate B: one md child plus one me1 child. chk[x] marks
		// that B won; e1path[x][2] is left untouched in that case.
		for (i = 0; i < 2; i++) {
			if (!~md[i].second) continue;
			for (j = 0; j < 2; j++) {
				if (!~me1[j].second) continue;
				if (md[i].second == me1[j].second) continue;
				ll sum = md[i].first + me1[j].first;
				if (mx < sum) {
					mx = sum;
					chk[x] = 1;
					e1path[x][0] = md[i].second;
					e1path[x][1] = me1[j].second;
				}
			}
		}
		end[1][x] += mx;
	}
	// Lowest common ancestor of u and v via the sp[] binary-lifting table
	// (requires dfs() to have filled sp[] and dep[]).
	inline int lca(int u, int v) {
		// Make v the deeper endpoint, then lift it to u's depth.
		if (dep[u] > dep[v]) swap(u, v);
		const int gap = dep[v] - dep[u];
		for (int bit = 0; bit < MAXS; ++bit) {
			if ((gap >> bit) & 1) v = sp[v][bit];
		}
		if (u == v) return u;
		// Lift both while their ancestors differ; they end just below the LCA.
		for (int bit = MAXS - 1; bit >= 0; --bit) {
			const int up_u = sp[u][bit];
			const int up_v = sp[v][bit];
			if (up_u != up_v) {
				u = up_u;
				v = up_v;
			}
		}
		return sp[u][0];
	}
	// Number of edges on the tree path between u and v, computed from depths.
	int dis(int u, int v) {
		const int meet = lca(u, v);
		return dep[u] + dep[v] - 2 * dep[meet];
	}
	// Visiting order built by calc()/down(); solve() prints it after its size.
	vector<int> ansv;
	// Appends to ansv x, the child dpath[x] (recursively, with c flipped) and
	// every other child of x. With c == 0 the order is x, dpath subtree, other
	// children; with c != 0 it is other children, dpath subtree, x.
	// p is x's parent. A vertex with a single neighbour is appended on its own.
	void calc(int x, int c, int p = 0) {
		if (adj[x].size() == 1) {
			ansv.push_back(x);
			return;
		}
		const int chosen = dpath[x];
		const bool enter_first = (c == 0);
		if (enter_first) {
			ansv.push_back(x);
			calc(chosen, c ^ 1, x);
		}
		for (const int nb : adj[x]) {
			if (nb == p || nb == chosen) continue;
			ansv.push_back(nb);
		}
		if (!enter_first) {
			calc(chosen, c ^ 1, x);
			ansv.push_back(x);
		}
	}
	// Replays the choices dfs() recorded for end[c][x] and appends the
	// corresponding vertices to ansv. p is x's parent (0 for the root).
	//
	// c == 0: x is appended first. Then either the walk continues into
	//         epath[x].second in mode 1 (epath[x].first == -1), or
	//         epath[x].first is expanded with calc(), the remaining children
	//         are appended on their own, and the walk continues into
	//         epath[x].second in mode 0.
	// c != 0: the children not named in e1path[x] are appended first, then
	//         e1path[x][0] is expanded with calc(), then x itself, then the
	//         walk continues according to chk[x].
	void down(int x, int c, int p = 0) {
		if (!c) {
			ansv.push_back(x);
			if (p && adj[x].size() == 1) return;
			if (!~epath[x].first) {
				down(epath[x].second, 1, x);
				return;
			}
			// first == 0 means dfs() recorded no child for calc().
			if (epath[x].first) calc(epath[x].first, 1, x);
			for (auto v : adj[x]) if (v != p) {
				if (v == epath[x].first) continue;
				if (v == epath[x].second) continue;
				ansv.push_back(v);
			}
			down(epath[x].second, 0, x);
		}
		else {
			assert(adj[x].size() > 2);
			// When chk[x] is set, dfs() chose the two-child candidate and
			// wrote only e1path[x][0..1]. Slot 2 may still hold a child from
			// the earlier three-child candidate; skipping it here would drop
			// that child from the route although its C is counted in
			// end[1][x]. So slot 2 is honoured only when chk[x] is 0.
			const bool uses_third = (chk[x] == 0);
			for (auto v : adj[x]) if (v != p) {
				if (v == e1path[x][0]) continue;
				if (v == e1path[x][1]) continue;
				if (uses_third && v == e1path[x][2]) continue;
				ansv.push_back(v);
			}
			calc(e1path[x][0], 0, x);
			ansv.push_back(x);
			if (!uses_third) down(e1path[x][1], 1, x);
			else {
				calc(e1path[x][1], 1, x);
				down(e1path[x][2], 0, x);
			}
		}
	}
	void solve() {
		dfs(1);
		cout << end[0][1] << ln;
		down(1, 0);
		cout << ansv.size() << ln;
		ll sum = 0;
		for (auto v : ansv) cout << v << bb, sum += C[v];
		int i;
		for (i = 1; i < ansv.size(); i++) {
			cout << i << ln;
			assert(dis(ansv[i], ansv[i - 1]) <= 2);
		}
		vector<int> cpy = ansv;
		sort(cpy.begin(), cpy.end());
		cpy.erase(unique(cpy.begin(), cpy.end()), cpy.end());
		assert(cpy.size() == ansv.size());
		//assert(sum == end[0][1]);
		cout << sum << Ln;
	}
}
namespace k3 {
	vector<int> ansv;
	void dfs(int x, int c, int p = 0) {
		if (c) ansv.push_back(x);
		for (auto v : adj[x]) if (v != p) dfs(v, c ^ 1, x);
		if (!c) ansv.push_back(x);
	}
	void solve() {
		ll sum = 0;
		int i;
		for (i = 1; i <= N; i++) sum += C[i];
		dfs(1, 1);
		cout << sum << ln;
		cout << N << Ln;
		for (auto v : ansv) cout << v << bb;
	}
}
signed main() {
	ios::sync_with_stdio(false), cin.tie(0);
	cin >> N >> K;
	int i, a, b;
	for (i = 1; i < N; i++) {
		cin >> a >> b;
		adj[a].push_back(b);
		adj[b].push_back(a);
	}
	for (i = 1; i <= N; i++) cin >> C[i];
	//if (K == 1) k1::solve();
	if (K == 2) k2::solve();
	//if (K == 3) k3::solve();
}

Compilation message

Main.cpp: In function 'void k2::solve()':
Main.cpp:231:17: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  231 |   for (i = 1; i < ansv.size(); i++) {
      |               ~~^~~~~~~~~~~~~
# Verdict Execution time Memory Grader output
1 Incorrect 3 ms 5076 KB Unexpected end of file - int64 expected
2 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Incorrect 3 ms 5076 KB Extra information in the output file
2 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Incorrect 3 ms 5076 KB Extra information in the output file
2 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Incorrect 3 ms 5076 KB Extra information in the output file
2 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Incorrect 3 ms 5076 KB Unexpected end of file - int64 expected
2 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Incorrect 3 ms 5076 KB Unexpected end of file - int64 expected
2 Halted 0 ms 0 KB -