답안 #1081793

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
1081793 2024-08-30T10:57:45 Z vuavisao Chase (CEOI17_chase) C++14
40 / 100
259 ms 524288 KB
#include <bits/stdc++.h>
#pragma GCC optimize("O3", "unroll-loops")
using namespace std;

// 64-bit signed integer used for every accumulated pigeon total in this file.
using ll = long long;

// Capacity of the per-node arrays; nodes are 1-indexed, so valid indices are 0 .. 100'000.
const int N = 100'000 + 1;

// n: node count, breads: second value on the first input line; the solvers use it as the
// upper bound on how many breadcrumbs may be dropped. Both are read in main().
int n, breads;
// pigeons[u] is read for u in 1..n. pigeons[0] is never written, so it keeps its static
// zero value; the solvers pass 0 as "no previous node" and rely on pigeons[0] == 0.
int pigeons[N];
// Undirected adjacency lists: every input edge (u, v) is stored in both g[u] and g[v].
vector<int> g[N];

namespace sub1 {
	bool check() {
		return (n <= 10);
	}

	int parent[N];
	ll curPigeons[N];

	ll getDiff(int u, ll cur) {
		ll after = 0;
		while (u != 0) {
			after += curPigeons[u];
			u = parent[u];
		}
		return abs(after - cur);
	}

	ll dfs(int u, ll cur, int use) {
		ll res = 0;
		// doesn't drop the breadcrumb
		cur += curPigeons[u];
		res = max(res, getDiff(u, cur));
		for (const auto& v : g[u]) {
			if (v == parent[u]) continue;
			parent[v] = u;
			res = max(res, dfs(v, cur, use));
		}

		// drop the breadcrumb
		if (use < breads) {
			vector<ll> oldPigeons(n + 1, 0);
			oldPigeons[u] = curPigeons[u];
			for (const auto& v : g[u]) {
				oldPigeons[v] = curPigeons[v];
			}
			for (const auto& v : g[u]) {
				curPigeons[u] += curPigeons[v];
				curPigeons[v] = 0;
			}

			res = max(res, getDiff(u, cur));
			
			for (const auto& v : g[u]) {
				if (v == parent[u]) continue;
				parent[v] = u;
				res = max(res, dfs(v, cur, use + 1));
			}

			curPigeons[u] = oldPigeons[u];
			for (const auto& v : g[u]) {
				curPigeons[v] = oldPigeons[v];
			}
		}
		return res;
	}

	void solve() {
		ll res = 0;
		for (int root = 1; root <= n; ++ root) {
			for (int u = 1; u <= n; ++ u) {
				curPigeons[u] = pigeons[u];
				parent[u] = 0;
			}
			res = max(res, dfs(root, 0, 0));
		}
		cout << res;
	}
}

namespace sub2 {
	// Gate for the "try every root" solver: O(n^2 * breads) work overall.
	bool check() {
		return (n <= 1'000);
	}

	// dp[u][k]: best total gathered on the walk from the current root down to u
	// using at most k breadcrumbs. Rows 0..n and columns 0..breads are used.
	ll dp[N][110];

	// Extends the walk that has just arrived at u from p (p == 0 for the root;
	// pigeons[0] is 0, so the root subtracts nothing).
	void dfs(int u, int p) {
		// Gain of dropping a breadcrumb at u: all neighbours' pigeons except
		// those of the node the walk came from.
		ll bonus = -pigeons[p];
		for (const auto& v : g[u]) {
			bonus += pigeons[v];
		}
		// Descending order so that each dp[u][use] is read before it is
		// overwritten, i.e. the breadcrumb at u is used at most once.
		for (int use = breads - 1; use >= 0; -- use) {
			dp[u][use + 1] = max(dp[u][use + 1], dp[u][use] + bonus);
		}
		for (const auto& v : g[u]) {
			if (v == p) continue;
			// The child inherits the best prefix values before it is expanded.
			for (int use = 0; use <= breads; ++ use) {
				dp[v][use] = max(dp[v][use], dp[u][use]);
			}
			dfs(v, u);
		}
	}

	// Returns the best value over all walks that start at `root`.
	ll calc(int root) {
		// Clear exactly the cells that dfs() and the scan below touch.
		// The previous bounds (u <= n + 1, use <= breads + 1) wrote to row
		// n + 1, which lies past the end of dp when n == N - 1; neither that
		// row nor column breads + 1 is ever read.
		for (int u = 0; u <= n; ++ u) {
			for (int use = 0; use <= breads; ++ use) {
				dp[u][use] = 0;
			}
		}
		dfs(root, 0);
		ll res = 0;
		for (int u = 1; u <= n; ++ u) {
			for (int use = 0; use <= breads; ++ use) {
				res = max(res, dp[u][use]);
			}
		}
		return res;
	}

	// Prints the maximum of calc(root) over every possible root.
	void solve() {
		ll res = 0;
		for (int root = 1; root <= n; ++ root) {
			res = max(res, calc(root));
		}
		cout << res;
	}
}

namespace sub4 {
	pair<ll, int> dpOut[N][101][2];
	pair<ll, int> dpIn[N][101][2];
	ll cost[N];
	ll res = 0;

	ll getCost(int u, int p) {
		return cost[u] - pigeons[p];
	}

	void update(pair<ll, int> cur[], pair<ll, int> val) {
		pair<ll, int> old = cur[0];
		if (val.first > cur[0].first) {
			cur[0] = val;
			cur[1] = max(cur[1], old);
		}
		else {
			cur[1] = max(cur[1], val);
		}
	}

	void dfsInOut(int u, int p) {
		// out is u can accept len >= 1
		for (int use = 0; use <= breads; ++ use) {
			dpOut[u][use][0] = (use == 0 ? make_pair(0ll, u) : make_pair(getCost(u, 0), u));
		}
		// in is u can accept len >= 2
		for (int use = 0; use <= breads; ++ use) {
			dpIn[u][use][0] = (use == 0 ? make_pair(0ll, -1) : make_pair(getCost(u, p), u));
		}

		for (const auto& v : g[u]) {
			if (v == p) continue;
			dfsInOut(v, u);

			// case out:
			// prefix u = v
			for (int use = 0; use <= breads; ++ use) {
				pair<ll, int> cur = make_pair(dpOut[v][use][0].first, v);
				update(dpOut[u][use], cur);
				if (use < breads) {
					cur.first += getCost(u, v);
					update(dpOut[u][use + 1], cur);
				}
			}

			// case in:
			// prefix v = u
			for (int use = 0; use <= breads; ++ use) {
				pair<ll, int> cur = make_pair(dpIn[v][use][0].first, v);
				update(dpIn[u][use], cur);
				if (use < breads) {
					cur.first += getCost(u, p);
					update(dpIn[u][use + 1], cur);
				}
			}
		}

		for (const auto& v : g[u]) {
			if (v == p) continue;
			
			for (int useOut = 0; useOut <= breads; ++ useOut) {
				int useIn = breads - useOut;
				for (int typeOut = 0; typeOut < 2; ++ typeOut) {
					for (int typeIn = 0; typeIn < 2; ++ typeIn) {
						if (dpOut[u][useOut][typeOut].second == v) continue;
						res = max(res, dpOut[u][useOut][typeOut].first + dpIn[v][useIn][typeIn].first);
					}
				}
			}
		}
		for (int useOut = 0; useOut <= breads; ++ useOut) {
			res = max(res, dpOut[u][useOut][0].first);
		}
	}

	void solve() {
		for (int u = 1; u <= n; ++ u) {
			for (const auto& v : g[u]) {
				cost[v] += pigeons[u];
			}
		}
		dfsInOut(1, 0);
		cout << res;
	}

}

// Reads "n breads", then n pigeon counts, then n - 1 undirected edges, and
// prints the answer computed by sub4.
int32_t main() {
	std::ios_base::sync_with_stdio(false);
	std::cin.tie(nullptr);
	std::cout.tie(nullptr);

	cin >> n >> breads;
	for (int node = 1; node <= n; ++node) {
		cin >> pigeons[node];
	}
	// A tree on n nodes has n - 1 edges; each one is stored in both directions.
	for (int edge = 1; edge < n; ++edge) {
		int a, b;
		cin >> a >> b;
		g[a].push_back(b);
		g[b].push_back(a);
	}

	// The sub1 / sub2 solvers (guarded by their check() functions) are
	// intentionally not dispatched here; sub4 handles every input.
	sub4::solve();
	return 0;
}

/// Code by vuavisao
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 2648 KB Output is correct
2 Correct 1 ms 2652 KB Output is correct
3 Correct 1 ms 2652 KB Output is correct
4 Correct 1 ms 2824 KB Output is correct
5 Correct 1 ms 2652 KB Output is correct
6 Correct 1 ms 2652 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 2648 KB Output is correct
2 Correct 1 ms 2652 KB Output is correct
3 Correct 1 ms 2652 KB Output is correct
4 Correct 1 ms 2824 KB Output is correct
5 Correct 1 ms 2652 KB Output is correct
6 Correct 1 ms 2652 KB Output is correct
7 Correct 5 ms 9048 KB Output is correct
8 Correct 4 ms 9052 KB Output is correct
9 Correct 5 ms 9052 KB Output is correct
10 Correct 6 ms 9052 KB Output is correct
11 Correct 4 ms 9052 KB Output is correct
12 Correct 4 ms 8972 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Runtime error 259 ms 524288 KB Execution killed with signal 9
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 2648 KB Output is correct
2 Correct 1 ms 2652 KB Output is correct
3 Correct 1 ms 2652 KB Output is correct
4 Correct 1 ms 2824 KB Output is correct
5 Correct 1 ms 2652 KB Output is correct
6 Correct 1 ms 2652 KB Output is correct
7 Correct 5 ms 9048 KB Output is correct
8 Correct 4 ms 9052 KB Output is correct
9 Correct 5 ms 9052 KB Output is correct
10 Correct 6 ms 9052 KB Output is correct
11 Correct 4 ms 9052 KB Output is correct
12 Correct 4 ms 8972 KB Output is correct
13 Runtime error 259 ms 524288 KB Execution killed with signal 9
14 Halted 0 ms 0 KB -