// Submission #1081819
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 1081819 | (not captured) | vuavisao | Chase (CEOI17_chase) | C++14 | 100 / 100 | 315 ms | 407888 KiB
#include <bits/stdc++.h>
using namespace std;

// Capacity of every per-vertex array: indices 0..100'000 are addressable.
// NOTE(review): presumably matches the problem's upper bound on n — confirm
// against the statement.
const int N = 100'000 + 1;

// n      — number of vertices (first integer read in main).
// breads — second integer read in main; every solver uses it as the upper
//          limit of its "use" counter.
int n, breads;
// pigeons[u] — value read for vertex u, 1-based. Index 0 is never written, so
// it keeps its static zero value; the solvers pass 0 as a "no parent" vertex
// and rely on pigeons[0] contributing nothing.
int pigeons[N];
// g[u] — adjacency list of vertex u; main stores each input edge in both
// directions.
vector<int> g[N];

namespace sub1 {
	bool check() {
		return (n <= 10);
	}

	int parent[N];
	long long curPigeons[N];

	long long getDiff(int u, long long cur) {
		long long after = 0;
		while (u != 0) {
			after += curPigeons[u];
			u = parent[u];
		}
		return abs(after - cur);
	}

	long long dfs(int u, long long cur, int use) {
		long long res = 0;

		cur += curPigeons[u];
		res = max(res, getDiff(u, cur));
		for (const auto& v : g[u]) {
			if (v == parent[u]) continue;
			parent[v] = u;
			res = max(res, dfs(v, cur, use));
		}

		if (use < breads) {
			vector<long long> oldPigeons(n + 1, 0);
			oldPigeons[u] = curPigeons[u];
			for (const auto& v : g[u]) {
				oldPigeons[v] = curPigeons[v];
			}
			for (const auto& v : g[u]) {
				curPigeons[u] += curPigeons[v];
				curPigeons[v] = 0;
			}

			res = max(res, getDiff(u, cur));
			
			for (const auto& v : g[u]) {
				if (v == parent[u]) continue;
				parent[v] = u;
				res = max(res, dfs(v, cur, use + 1));
			}

			curPigeons[u] = oldPigeons[u];
			for (const auto& v : g[u]) {
				curPigeons[v] = oldPigeons[v];
			}
		}
		return res;
	}

	void solve() {
		long long res = 0;
		for (int root = 1; root <= n; ++ root) {
			for (int u = 1; u <= n; ++ u) {
				curPigeons[u] = pigeons[u];
				parent[u] = 0;
			}
			res = max(res, dfs(root, 0, 0));
		}
		cout << res;
	}
}

// Quadratic solver: for each candidate root, one top-down DP over the tree.
// Selected by main only when check() accepts the input size.
namespace sub2 {
	// Used for trees with at most 1'000 vertices.
	bool check() {
		return n <= 1'000;
	}

	// dp[u][k]: best value accumulated on the way from the current root down
	// to u, indexed by a counter k in 0..breads.
	long long dp[N][110];

	// Finalises dp[u] (the row must already hold what was inherited from p),
	// then seeds and visits every child of u.
	void dfs(int u, int p) {
		// Gain of spending one unit at u: all neighbours' values, minus the
		// vertex we arrived from (pigeons[0] is 0 for the root).
		long long gain = 0;
		for (int nb : g[u]) {
			gain += pigeons[nb];
		}
		gain -= pigeons[p];

		// Walk k downwards so each dp[u][k - 1] read is still the value from
		// before this update — at most one unit is spent at u.
		for (int k = breads; k >= 1; --k) {
			dp[u][k] = max(dp[u][k], dp[u][k - 1] + gain);
		}

		for (int child : g[u]) {
			if (child == p) {
				continue;
			}
			// The child starts from whatever u achieved.
			for (int k = 0; k <= breads; ++k) {
				dp[child][k] = max(dp[child][k], dp[u][k]);
			}
			dfs(child, u);
		}
	}

	// Clears the table, runs the DP rooted at `root`, and returns the largest
	// entry over vertices 1..n and counters 0..breads.
	long long calc(int root) {
		for (int u = 0; u <= n + 1; ++u) {
			for (int k = 0; k <= breads + 1; ++k) {
				dp[u][k] = 0;
			}
		}

		dfs(root, 0);

		long long best = 0;
		for (int u = 1; u <= n; ++u) {
			for (int k = 0; k <= breads; ++k) {
				best = max(best, dp[u][k]);
			}
		}
		return best;
	}

	// Tries every vertex as the root and prints the overall maximum.
	void solve() {
		long long answer = 0;
		for (int root = 1; root <= n; ++root) {
			answer = max(answer, calc(root));
		}
		cout << answer;
	}
}

// General solver: a single DFS from vertex 1 that, at every vertex u, joins a
// candidate from dpOut[u] with a candidate from dpIn[v] of a child v.
// NOTE(review): the second array dimension is 101, so this presumably assumes
// breads <= 100 — confirm against the problem constraints.
namespace sub4 {
	// dpOut[u][use][0] / [1]: best and runner-up (value, tag) pairs for u at
	// counter `use`. The tag is u itself for the initial entry, or the child
	// that supplied the candidate. Slot [1] is never assigned before update()
	// touches it, so it starts from the static zero value (0, 0).
	pair<long long, int> dpOut[N][101][2];
	// dpIn[u][use]: single best value for u at counter `use`, built from
	// getCost(u, parent) and the children's dpIn rows.
	long long dpIn[N][101];
	// cost[u]: sum of pigeons[] over all neighbours of u (filled in solve()).
	long long cost[N];
	// Running maximum over every combination examined; printed by solve().
	long long res = 0;

	// Neighbour sum of u with the contribution of vertex p removed.
	// Passing p == 0 removes nothing, since pigeons[0] is never written.
	long long getCost(int u, int p) {
		return cost[u] - pigeons[p];
	}

	// Inserts `val` into the two-slot ranking `cur`: a strictly larger value
	// takes slot 0 and demotes the previous leader to slot 1; otherwise `val`
	// only competes for slot 1 (pair comparison: value first, then tag).
	void update(pair<long long, int> cur[], pair<long long, int> val) {
		pair<long long, int> old = cur[0];
		if (val.first > cur[0].first) {
			cur[0] = val;
			cur[1] = old;
		}
		else {
			cur[1] = max(cur[1], val);
		}
	}

	// Post-order DFS over the tree rooted at the first call's u, with parent p.
	// Fills dpOut[u] and dpIn[u] from the children, then raises `res` with
	// every combination available at u.
	void dfsInOut(int u, int p) {
		// Base entries for u alone: nothing at use == 0; for use >= 1 the full
		// neighbour sum (dpOut) or the neighbour sum without the parent (dpIn).
		for (int use = 0; use <= breads; ++ use) {
			dpOut[u][use][0] = (use == 0 ? make_pair(0ll, u) : make_pair(getCost(u, 0), u));
		}
		for (int use = 0; use <= breads; ++ use) {
			dpIn[u][use] = (use == 0 ? 0ll : getCost(u, p));
		}

		for (const auto& v : g[u]) {
			if (v == p) continue;
			dfsInOut(v, u);

			// Candidate for dpOut[u] coming through child v: either carry v's
			// best unchanged, or spend one more unit at u, which gains u's
			// neighbour sum without v. Tagged with v so it can be excluded
			// later when pairing with dpIn[v].
			for (int use = 0; use <= breads; ++ use) {
				pair<long long, int> cur = make_pair(dpOut[v][use][0].first, v);
				if (use > 0) {
					cur.first = max(cur.first, dpOut[v][use - 1][0].first + getCost(u, v));
				}
				update(dpOut[u][use], cur);
			}

			// dpIn[u] through child v: carry v's value unchanged, or spend one
			// more unit at u, which gains u's neighbour sum without the parent.
			for (int use = 0; use <= breads; ++ use) {
				long long cur = dpIn[v][use];
				dpIn[u][use] = max(dpIn[u][use], cur);
				if (use < breads) {
					cur += getCost(u, p);
					dpIn[u][use + 1] = max(dpIn[u][use + 1], cur);
				}
			}
		}

		// Pair a dpOut[u] candidate with dpIn[v] of a child v, splitting the
		// budget as useOut + useIn == breads. Candidates tagged with v itself
		// are skipped, which is why the runner-up slot is kept.
		for (const auto& v : g[u]) {
			if (v == p) continue;
			
			for (int useOut = 0; useOut <= breads; ++ useOut) {
				int useIn = breads - useOut;
				for (int typeOut = 0; typeOut < 2; ++ typeOut) {
					if (dpOut[u][useOut][typeOut].second == v) continue;
					res = max(res, dpOut[u][useOut][typeOut].first + dpIn[v][useIn]);
				}
			}
		}
		// Each table entry at u is also a candidate answer on its own.
		for (int useOut = 0; useOut <= breads; ++ useOut) {
			res = max(res, dpOut[u][useOut][0].first);
		}
		for (int useIn = 0; useIn <= breads; ++ useIn) {
			res = max(res, dpIn[u][useIn]);
		}
	}

	// Builds cost[], runs the DFS from vertex 1 and prints the answer.
	// cost[] and res are accumulated into, not reset, so this is written to be
	// called once per program run.
	void solve() {
		for (int u = 1; u <= n; ++ u) {
			for (const auto& v : g[u]) {
				cost[v] += pigeons[u];
			}
		}
		dfsInOut(1, 0);
		cout << res;
	}

}

// Reads the tree from standard input and hands it to the cheapest solver
// whose size check accepts it: sub1, then sub2, otherwise sub4.
int main() {
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);
	cout.tie(nullptr);

	// Header line, then one value per vertex (1-based).
	cin >> n >> breads;
	for (int u = 1; u <= n; ++u) {
		cin >> pigeons[u];
	}

	// n - 1 undirected edges, stored in both adjacency lists.
	for (int edge = 1; edge < n; ++edge) {
		int a, b;
		cin >> a >> b;
		g[a].push_back(b);
		g[b].push_back(a);
	}

	if (sub1::check()) {
		sub1::solve();
	} else if (sub2::check()) {
		sub2::solve();
	} else {
		sub4::solve();
	}
	return 0;
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...