// Submission #1081798
//
// # | Submitted at | User | Problem | Language | Result | Run time | Memory
// 1081798 | 2024-08-30T11:04:35 Z | vuavisao | Chase (CEOI17_chase) | C++14
// 70 / 100
// 332 ms | 408528 KB
#include <bits/stdc++.h>
#pragma GCC optimize("O3", "unroll-loops")
using namespace std;

// 64-bit integer used for every pigeon sum in this file.
using ll = long long;

// Capacity of all per-node arrays. Nodes are numbered from 1, so index 0 is
// free and is used as a "no node" marker (its pigeon count stays 0).
const int N = 100'000 + 1;

// n: number of nodes; breads: number of breadcrumbs available. Both are read in main().
int n, breads;
// pigeons[u]: pigeon count of node u, filled for u = 1..n in main().
int pigeons[N];
// g[u]: adjacency list of node u; main() stores every edge in both directions.
vector<int> g[N];

namespace sub1 {
	// Gate for the exhaustive search: true when the tree has at most 10 nodes.
	bool check() {
		const int kMaxNodes = 10;
		return !(n > kMaxNodes);
	}

	// parent[v]: node from which the current search reached v; 0 marks the
	// first node of the walk (getDiff stops climbing there).
	int parent[N];
	// curPigeons[v]: working pigeon count of v. dfs changes it while a
	// breadcrumb is in effect and restores it before returning.
	ll curPigeons[N];

	// Sums curPigeons along the parent chain that starts at u and ends at the
	// first node whose parent is 0, then returns the absolute difference
	// between that sum and `cur`.
	ll getDiff(int u, ll cur) {
		ll chainTotal = 0;
		for (int node = u; node != 0; node = parent[node]) {
			chainTotal += curPigeons[node];
		}
		const ll gap = chainTotal - cur;
		return gap < 0 ? -gap : gap;
	}

	// Exhaustive search over walks that continue from u away from parent[u].
	//   u   - node being entered
	//   cur - pigeon total accumulated before entering u
	//   use - breadcrumbs already dropped on the way here
	// Returns the largest getDiff value seen at u or anywhere deeper, trying
	// both "no breadcrumb at u" and (if any are left) "breadcrumb at u".
	ll dfs(int u, ll cur, int use) {
		// Branch A: pass through u without dropping a breadcrumb.
		cur += curPigeons[u];
		ll best = max(0ll, getDiff(u, cur));
		for (int next : g[u]) {
			if (next != parent[u]) {
				parent[next] = u;
				best = max(best, dfs(next, cur, use));
			}
		}

		// Out of breadcrumbs: branch B is not available.
		if (use >= breads) {
			return best;
		}

		// Branch B: drop a breadcrumb at u. Remember every count that is about
		// to change so the state can be put back before returning.
		vector<ll> saved(n + 1, 0);
		saved[u] = curPigeons[u];
		for (int nb : g[u]) {
			saved[nb] = curPigeons[nb];
		}
		// All pigeons of the neighbours move onto u.
		for (int nb : g[u]) {
			curPigeons[u] += curPigeons[nb];
			curPigeons[nb] = 0;
		}

		best = max(best, getDiff(u, cur));
		for (int next : g[u]) {
			if (next != parent[u]) {
				parent[next] = u;
				best = max(best, dfs(next, cur, use + 1));
			}
		}

		// Undo the breadcrumb.
		for (int nb : g[u]) {
			curPigeons[nb] = saved[nb];
		}
		curPigeons[u] = saved[u];
		return best;
	}

	void solve() {
		ll res = 0;
		for (int root = 1; root <= n; ++ root) {
			for (int u = 1; u <= n; ++ u) {
				curPigeons[u] = pigeons[u];
				parent[u] = 0;
			}
			res = max(res, dfs(root, 0, 0));
		}
		cout << res;
	}
}

namespace sub2 {
	// Gate for the per-root DP: true when the tree has at most 1000 nodes.
	bool check() {
		const int kMaxNodes = 1000;
		return !(n > kMaxNodes);
	}

	// dp[u][use]: best gain found so far for a walk from the current root down
	// to u with a budget of `use` breadcrumbs; calc() clears it for every root.
	// NOTE(review): the second dimension (110) presumably covers the maximum
	// breadcrumb count plus slack — confirm against the problem limits.
	ll dp[N][110];

	void dfs(int u, int p) {
		ll bonus = -pigeons[p];
		for (const auto& v : g[u]) {
			bonus += pigeons[v];
		}
		for (int use = breads - 1; use >= 0; -- use) {
			dp[u][use + 1] = max(dp[u][use + 1], dp[u][use] + bonus);
		}
		for (const auto& v : g[u]) {
			if (v == p) continue;
			for (int use = 0; use <= breads; ++ use) {
				dp[v][use] = max(dp[v][use], dp[u][use]);
			}
			dfs(v, u);
		}
	}

	// Clears the DP table, runs the top-down DP with `root` as the start of
	// the walk, and returns the largest dp value over all nodes and all
	// breadcrumb counts 0..breads.
	ll calc(int root) {
		// Stop at row n: no row above n is ever read, and row n + 1 would lie
		// outside dp when n == N - 1 (dp has exactly N rows).
		for (int u = 0; u <= n; ++ u) {
			for (int use = 0; use <= breads + 1; ++ use) {
				dp[u][use] = 0;
			}
		}
		dfs(root, 0);
		ll res = 0;
		for (int u = 1; u <= n; ++ u) {
			for (int use = 0; use <= breads; ++ use) {
				res = max(res, dp[u][use]);
			}
		}
		return res;
	}

	void solve() {
		ll res = 0;
		for (int root = 1; root <= n; ++ root) {
			res = max(res, calc(root));
		}
		cout << res;
	}
}

namespace sub4 {
	// dpOut[u][use][k]: (gain, tag) candidates for walks that start inside u's
	// subtree and end at u with a budget of `use` breadcrumbs. tag is the child
	// of u the walk arrived through, or u itself when the walk is just u.
	// Slot 0 is the best candidate, slot 1 a runner-up maintained by update()
	// so that a given child can be excluded when two halves are joined.
	pair<ll, int> dpOut[N][101][2];
	// dpIn[u][use]: best gain of a walk that enters u from its parent and then
	// continues downwards inside u's subtree, with a budget of `use` breadcrumbs.
	ll dpIn[N][101];
	// cost[u]: sum of pigeons over all neighbours of u (filled in solve()).
	ll cost[N];
	// Best answer found so far; printed by solve().
	ll res = 0;

	// Gain of dropping a breadcrumb at u when the walk arrived from p: the
	// pigeons of all neighbours of u except p. Pass p == 0 for "no previous
	// node" (pigeons[0] is 0).
	ll getCost(int u, int p) {
		const ll excluded = pigeons[p];
		return cost[u] - excluded;
	}

	// Merges candidate `val` = (gain, tag) into the two-slot table `cur`.
	//
	// Invariant kept by this function:
	//   cur[0] - largest gain seen over all tags
	//   cur[1] - largest gain seen among tags different from cur[0].second
	// so the two slots never carry the same tag. The caller relies on this when
	// it discards every entry tagged with one particular child: whichever child
	// is excluded, the best remaining candidate is still in cur[0] or cur[1].
	// Without the distinct-tag rule, two candidates coming from the same child
	// could occupy both slots and hide the best candidate of every other child.
	//
	//   cur - array of exactly two slots, updated in place
	//   val - new candidate; its tag identifies where the walk came from
	void update(std::pair<long long, int> cur[], std::pair<long long, int> val) {
		// Same source as the current best: only its gain can improve.
		if (val.second == cur[0].second) {
			if (val.first > cur[0].first) {
				cur[0].first = val.first;
			}
			return;
		}
		// Same source as the runner-up: improve it and promote it if it now
		// beats the best (the old best then becomes the runner-up).
		if (val.second == cur[1].second) {
			if (val.first > cur[1].first) {
				cur[1].first = val.first;
			}
			if (cur[1].first > cur[0].first) {
				std::swap(cur[0], cur[1]);
			}
			return;
		}
		// A source not stored yet.
		if (val.first > cur[0].first) {
			cur[1] = cur[0];
			cur[0] = val;
		}
		else if (val.first > cur[1].first) {
			cur[1] = val;
		}
	}

	void dfsInOut(int u, int p) {
		// out is u can accept len >= 1
		for (int use = 0; use <= breads; ++ use) {
			dpOut[u][use][0] = (use == 0 ? make_pair(0ll, u) : make_pair(getCost(u, 0), u));
		}
		// in is u can accept len >= 2
		for (int use = 0; use <= breads; ++ use) {
			dpIn[u][use] = (use == 0 ? 0ll : getCost(u, p));
		}

		for (const auto& v : g[u]) {
			if (v == p) continue;
			dfsInOut(v, u);

			// case out:
			// prefix u = v
			for (int use = 0; use <= breads; ++ use) {
				pair<ll, int> cur = make_pair(dpOut[v][use][0].first, v);
				update(dpOut[u][use], cur);
				if (use < breads) {
					cur.first += getCost(u, v);
					update(dpOut[u][use + 1], cur);
				}
			}

			// case in:
			// prefix v = u
			for (int use = 0; use <= breads; ++ use) {
				auto cur = dpIn[v][use];
				dpIn[u][use] = max(dpIn[u][use], cur);
				if (use < breads) {
					cur += getCost(u, p);
					dpIn[u][use + 1] = max(dpIn[u][use + 1], cur);
				}
			}
		}

		for (const auto& v : g[u]) {
			if (v == p) continue;
			
			for (int useOut = 0; useOut <= breads; ++ useOut) {
				int useIn = breads - useOut;
				for (int typeOut = 0; typeOut < 2; ++ typeOut) {
					if (dpOut[u][useOut][typeOut].second == v) continue;
					res = max(res, dpOut[u][useOut][typeOut].first + dpIn[v][useIn]);
				}
			}
		}
		for (int useOut = 0; useOut <= breads; ++ useOut) {
			res = max(res, dpOut[u][useOut][0].first);
		}
		for (int useIn = 0; useIn <= breads; ++ useIn) {
			res = max(res, dpIn[u][useIn]);
		}
	}

	void solve() {
		for (int u = 1; u <= n; ++ u) {
			for (const auto& v : g[u]) {
				cost[v] += pigeons[u];
			}
		}
		dfsInOut(1, 0);
		cout << res;
	}

}

int32_t main() {
	std::ios_base::sync_with_stdio(false);
	std::cin.tie(nullptr); std::cout.tie(nullptr);
	cin >> n >> breads;
	for (int u = 1; u <= n; ++ u) cin >> pigeons[u];
	for (int i = 2; i <= n; ++ i) {
		int u, v; cin >> u >> v;
		g[u].push_back(v);
		g[v].push_back(u);
	}
	// if (sub1::check()) {
	// 	sub1::solve();
	// 	return 0;
	// }
	// if (sub2::check()) {
	// 	sub2::solve();
	// 	return 0;
	// }
	// cout << sub2::calc(1);
	sub4::solve();
	return (0 ^ 0);
}

/// Code by vuavisao
// # | Result | Run time | Memory | Grader output
// 1 Correct 1 ms 2648 KB Output is correct
// 2 Correct 1 ms 2648 KB Output is correct
// 3 Correct 2 ms 2652 KB Output is correct
// 4 Correct 1 ms 2652 KB Output is correct
// 5 Correct 1 ms 2652 KB Output is correct
// 6 Correct 1 ms 2652 KB Output is correct
// # | Result | Run time | Memory | Grader output
// 1 Correct 1 ms 2648 KB Output is correct
// 2 Correct 1 ms 2648 KB Output is correct
// 3 Correct 2 ms 2652 KB Output is correct
// 4 Correct 1 ms 2652 KB Output is correct
// 5 Correct 1 ms 2652 KB Output is correct
// 6 Correct 1 ms 2652 KB Output is correct
// 7 Correct 6 ms 6748 KB Output is correct
// 8 Correct 3 ms 6748 KB Output is correct
// 9 Correct 3 ms 6748 KB Output is correct
// 10 Correct 3 ms 6748 KB Output is correct
// 11 Correct 3 ms 6748 KB Output is correct
// 12 Correct 3 ms 6748 KB Output is correct
// # | Result | Run time | Memory | Grader output
// 1 Correct 320 ms 408528 KB Output is correct
// 2 Correct 298 ms 408404 KB Output is correct
// 3 Correct 244 ms 402636 KB Output is correct
// 4 Correct 181 ms 402372 KB Output is correct
// 5 Correct 320 ms 402260 KB Output is correct
// 6 Correct 332 ms 402260 KB Output is correct
// 7 Correct 314 ms 402496 KB Output is correct
// # | Result | Run time | Memory | Grader output
// 1 Correct 1 ms 2648 KB Output is correct
// 2 Correct 1 ms 2648 KB Output is correct
// 3 Correct 2 ms 2652 KB Output is correct
// 4 Correct 1 ms 2652 KB Output is correct
// 5 Correct 1 ms 2652 KB Output is correct
// 6 Correct 1 ms 2652 KB Output is correct
// 7 Correct 6 ms 6748 KB Output is correct
// 8 Correct 3 ms 6748 KB Output is correct
// 9 Correct 3 ms 6748 KB Output is correct
// 10 Correct 3 ms 6748 KB Output is correct
// 11 Correct 3 ms 6748 KB Output is correct
// 12 Correct 3 ms 6748 KB Output is correct
// 13 Correct 320 ms 408528 KB Output is correct
// 14 Correct 298 ms 408404 KB Output is correct
// 15 Correct 244 ms 402636 KB Output is correct
// 16 Correct 181 ms 402372 KB Output is correct
// 17 Correct 320 ms 402260 KB Output is correct
// 18 Correct 332 ms 402260 KB Output is correct
// 19 Correct 314 ms 402496 KB Output is correct
// 20 Correct 326 ms 402504 KB Output is correct
// 21 Correct 171 ms 402512 KB Output is correct
// 22 Incorrect 327 ms 402512 KB Output isn't correct
// 23 Halted 0 ms 0 KB -