답안 #1081799

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
1081799 2024-08-30T11:05:42 Z vuavisao Chase (CEOI17_chase) C++14
70 / 100
338 ms 444500 KB
#include <bits/stdc++.h>
#pragma GCC optimize("O3", "unroll-loops")
using namespace std;

// 64-bit type used for every pigeon sum / score in the file.
using ll = long long;

// Capacity of the 1-indexed per-vertex arrays, with a little slack.
constexpr int N = 100'010;

int n;             // number of vertices (read first in main)
int breads;        // breadcrumb budget
int pigeons[N];    // pigeons[u] for u in 1..n; pigeons[0] is never written and stays 0
vector<int> g[N];  // undirected adjacency lists

// Exhaustive simulation, intended only for tiny trees (check(): n <= 10).
// main() does not dispatch to it in this submission (the call is commented out).
namespace sub1 {
	// True when n is small enough for the exhaustive simulation below.
	bool check() {
		return (n <= 10);
	}

	// parent[u]: predecessor of u on the path currently being explored (0 = none).
	int parent[N];
	// curPigeons[u]: pigeons currently on u, after the breadcrumbs dropped so far
	// on the current path have pulled neighbours' pigeons in.
	ll curPigeons[N];

	// Walks the parent chain from u back to the path's start, summing the
	// *current* pigeon counts ("after"), and returns |after - cur|, where cur is
	// the running total recorded when each path vertex was first reached.
	ll getDiff(int u, ll cur) {
		ll after = 0;
		while (u != 0) {
			after += curPigeons[u];
			u = parent[u];
		}
		return abs(after - cur);
	}

	// Enumerates every downward continuation of the current path from u and every
	// choice of breadcrumb placement on it (`use` = breadcrumbs already dropped).
	// Returns the best getDiff value seen. curPigeons is restored before
	// returning, so sibling branches observe the same state.
	ll dfs(int u, ll cur, int use) {
		ll res = 0;
		// Option 1: do not drop a breadcrumb on u.
		cur += curPigeons[u];
		res = max(res, getDiff(u, cur));
		for (const auto& v : g[u]) {
			if (v == parent[u]) continue;
			parent[v] = u;
			res = max(res, dfs(v, cur, use));
		}

		// Option 2: drop a breadcrumb on u (if any remain).
		if (use < breads) {
			// Snapshot u and its neighbours so the move can be undone.
			// Note: a fresh (n + 1)-sized vector per call; fine only because n <= 10.
			vector<ll> oldPigeons(n + 1, 0);
			oldPigeons[u] = curPigeons[u];
			for (const auto& v : g[u]) {
				oldPigeons[v] = curPigeons[v];
			}
			// Move every neighbour's pigeons onto u, including parent[u]'s, which
			// are already included in cur from when the parent was reached.
			for (const auto& v : g[u]) {
				curPigeons[u] += curPigeons[v];
				curPigeons[v] = 0;
			}

			res = max(res, getDiff(u, cur));
			
			for (const auto& v : g[u]) {
				if (v == parent[u]) continue;
				parent[v] = u;
				res = max(res, dfs(v, cur, use + 1));
			}

			// Undo the move.
			curPigeons[u] = oldPigeons[u];
			for (const auto& v : g[u]) {
				curPigeons[v] = oldPigeons[v];
			}
		}
		return res;
	}

	// Tries every vertex as the start of the path, resetting the state each time,
	// and prints the best value found.
	void solve() {
		ll res = 0;
		for (int root = 1; root <= n; ++ root) {
			for (int u = 1; u <= n; ++ u) {
				curPigeons[u] = pigeons[u];
				parent[u] = 0;
			}
			res = max(res, dfs(root, 0, 0));
		}
		cout << res;
	}
}

namespace sub2 {
	bool check() {
		return (n <= 1'000);
	}

	ll dp[N][110];

	void dfs(int u, int p) {
		ll bonus = -pigeons[p];
		for (const auto& v : g[u]) {
			bonus += pigeons[v];
		}
		for (int use = breads - 1; use >= 0; -- use) {
			dp[u][use + 1] = max(dp[u][use + 1], dp[u][use] + bonus);
		}
		for (const auto& v : g[u]) {
			if (v == p) continue;
			for (int use = 0; use <= breads; ++ use) {
				dp[v][use] = max(dp[v][use], dp[u][use]);
			}
			dfs(v, u);
		}
	}

	ll calc(int root) {
		for (int u = 0; u <= n + 1; ++ u) {
			for (int use = 0; use <= breads + 1; ++ use) {
				dp[u][use] = 0;
			}
		}
		dfs(root, 0);
		ll res = 0;
		for (int u = 1; u <= n; ++ u) {
			for (int use = 0; use <= breads; ++ use) {
				res = max(res, dp[u][use]);
			}
		}
		return res;
	}

	void solve() {
		ll res = 0;
		for (int root = 1; root <= n; ++ root) {
			res = max(res, calc(root));
		}
		cout << res;
	}
}

namespace sub4 {
	pair<ll, int> dpOut[N][110][2];
	ll dpIn[N][110];
	ll cost[N];
	ll res = 0;

	ll getCost(int u, int p) {
		return cost[u] - pigeons[p];
	}

	void update(pair<ll, int> cur[], pair<ll, int> val) {
		pair<ll, int> old = cur[0];
		if (val.first > cur[0].first) {
			cur[0] = val;
			cur[1] = max(cur[1], old);
		}
		else {
			cur[1] = max(cur[1], val);
		}
	}

	void dfsInOut(int u, int p) {
		// out is u can accept len >= 1
		for (int use = 0; use <= breads; ++ use) {
			dpOut[u][use][0] = (use == 0 ? make_pair(0ll, u) : make_pair(getCost(u, 0), u));
		}
		// in is u can accept len >= 2
		for (int use = 0; use <= breads; ++ use) {
			dpIn[u][use] = (use == 0 ? 0ll : getCost(u, p));
		}

		for (const auto& v : g[u]) {
			if (v == p) continue;
			dfsInOut(v, u);

			// case out:
			// prefix u = v
			for (int use = 0; use <= breads; ++ use) {
				pair<ll, int> cur = make_pair(dpOut[v][use][0].first, v);
				update(dpOut[u][use], cur);
				if (use < breads) {
					cur.first += getCost(u, v);
					update(dpOut[u][use + 1], cur);
				}
			}

			// case in:
			// prefix v = u
			for (int use = 0; use <= breads; ++ use) {
				auto cur = dpIn[v][use];
				dpIn[u][use] = max(dpIn[u][use], cur);
				if (use < breads) {
					cur += getCost(u, p);
					dpIn[u][use + 1] = max(dpIn[u][use + 1], cur);
				}
			}
		}

		for (const auto& v : g[u]) {
			if (v == p) continue;
			
			for (int useOut = 0; useOut <= breads; ++ useOut) {
				int useIn = breads - useOut;
				for (int typeOut = 0; typeOut < 2; ++ typeOut) {
					if (dpOut[u][useOut][typeOut].second == v) continue;
					res = max(res, dpOut[u][useOut][typeOut].first + dpIn[v][useIn]);
				}
			}
		}
		for (int useOut = 0; useOut <= breads; ++ useOut) {
			res = max(res, dpOut[u][useOut][0].first);
		}
		for (int useIn = 0; useIn <= breads; ++ useIn) {
			res = max(res, dpIn[u][useIn]);
		}
	}

	void solve() {
		for (int u = 1; u <= n; ++ u) {
			for (const auto& v : g[u]) {
				cost[v] += pigeons[u];
			}
		}
		dfsInOut(1, 0);
		cout << res;
	}

}

int32_t main() {
	// Decouple iostreams from C stdio and from each other; all I/O uses cin/cout.
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);
	cout.tie(nullptr);

	// Input: n and the breadcrumb budget, each vertex's pigeon count, then the
	// n - 1 undirected edges.
	cin >> n >> breads;
	for (int u = 1; u <= n; ++u) {
		cin >> pigeons[u];
	}
	for (int edge = 1; edge < n; ++edge) {
		int a, b;
		cin >> a >> b;
		g[a].push_back(b);
		g[b].push_back(a);
	}

	// Only the O(n * breads) DP is dispatched. sub1 / sub2 (small-n brute forces)
	// stay in the file for cross-checking.
	sub4::solve();
	return 0;
}

/// Code by vuavisao
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 2652 KB Output is correct
2 Correct 1 ms 2652 KB Output is correct
3 Correct 1 ms 2652 KB Output is correct
4 Correct 1 ms 2652 KB Output is correct
5 Correct 1 ms 2652 KB Output is correct
6 Correct 1 ms 2652 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 2652 KB Output is correct
2 Correct 1 ms 2652 KB Output is correct
3 Correct 1 ms 2652 KB Output is correct
4 Correct 1 ms 2652 KB Output is correct
5 Correct 1 ms 2652 KB Output is correct
6 Correct 1 ms 2652 KB Output is correct
7 Correct 4 ms 7004 KB Output is correct
8 Correct 3 ms 7004 KB Output is correct
9 Correct 3 ms 6948 KB Output is correct
10 Correct 4 ms 7004 KB Output is correct
11 Correct 4 ms 7000 KB Output is correct
12 Correct 3 ms 7000 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 315 ms 444500 KB Output is correct
2 Correct 305 ms 444464 KB Output is correct
3 Correct 255 ms 437932 KB Output is correct
4 Correct 194 ms 437588 KB Output is correct
5 Correct 337 ms 437584 KB Output is correct
6 Correct 338 ms 437584 KB Output is correct
7 Correct 316 ms 437580 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 2652 KB Output is correct
2 Correct 1 ms 2652 KB Output is correct
3 Correct 1 ms 2652 KB Output is correct
4 Correct 1 ms 2652 KB Output is correct
5 Correct 1 ms 2652 KB Output is correct
6 Correct 1 ms 2652 KB Output is correct
7 Correct 4 ms 7004 KB Output is correct
8 Correct 3 ms 7004 KB Output is correct
9 Correct 3 ms 6948 KB Output is correct
10 Correct 4 ms 7004 KB Output is correct
11 Correct 4 ms 7000 KB Output is correct
12 Correct 3 ms 7000 KB Output is correct
13 Correct 315 ms 444500 KB Output is correct
14 Correct 305 ms 444464 KB Output is correct
15 Correct 255 ms 437932 KB Output is correct
16 Correct 194 ms 437588 KB Output is correct
17 Correct 337 ms 437584 KB Output is correct
18 Correct 338 ms 437584 KB Output is correct
19 Correct 316 ms 437580 KB Output is correct
20 Correct 319 ms 437584 KB Output is correct
21 Correct 174 ms 437708 KB Output is correct
22 Incorrect 325 ms 437584 KB Output isn't correct
23 Halted 0 ms 0 KB -