답안 #682164

# | 제출 시각 (submitted at) | 아이디 (user) | 문제 (problem) | 언어 (language) | 결과 (result) | 실행 시간 (time) | 메모리 (memory)
682164 | 2023-01-16T02:46:34Z | GusterGoose27 | Chase (CEOI17_chase) | C++11 | 30 / 100 | 928 ms | 172356 KB
#include <bits/stdc++.h>

using namespace std;

using ll = long long;

// Capacity limits: MAXN node slots, and MAXV = (max budget v) + 1 so that
// the last DP dimension can be indexed 0..v.
constexpr int MAXN = 100000;
constexpr int MAXV = 101;

// DP tables, recomputed for each decomposition step and indexed [dir][node][j]:
//   dir 0 ("away")    - path enters the node from its parent and goes down;
//   dir 1 ("towards") - path starts below the node and climbs up to it;
//   j                 - budget of chosen nodes (index 0 is never written, so it stays 0).
ll dp[2][MAXN][MAXV];
int val[MAXN];                  // per-node value read from input
ll adj_sum[MAXN];               // sum of val[] over each node's neighbours
std::vector<int> edges[MAXN];   // undirected adjacency lists
bool vis[MAXN];                 // node already removed by the decomposition
int sz[MAXN];                   // subtree sizes from the most recent DFS
int n, v;                       // node count and budget, read in main
ll ans = 0;                     // best value found so far

void get_sz(int cur, int p = -1) {
	sz[cur] = 1;
	if (p != -1) {
		for (int j = 1; j <= v; j++) {
			dp[0][cur][j] = adj_sum[cur]-val[p];
			dp[1][cur][j] = adj_sum[cur];
		}
	}
	for (int nxt: edges[cur]) {
		if (nxt == p || vis[nxt]) continue;
		get_sz(nxt, cur);
		sz[cur] += sz[nxt];
		if (p != -1) {
			for (int j = 1; j <= v; j++) {
				dp[0][cur][j] = max(dp[0][cur][j], max(dp[0][nxt][j], dp[0][nxt][j-1]+adj_sum[cur]-val[p]));
				dp[1][cur][j] = max(dp[1][cur][j], max(dp[1][nxt][j], dp[1][nxt][j-1]+adj_sum[cur]-val[nxt]));
			}
		}
	}
}

// Sizes-only DFS: fills sz[] for the active component (vis == false) that
// contains cur, rooted at cur. Unlike get_sz it leaves the dp tables alone,
// so locating the centroid costs O(component size) rather than O(size * v).
static void comp_sizes(int cur, int p) {
	sz[cur] = 1;
	for (int nxt : edges[cur]) {
		if (nxt == p || vis[nxt]) continue;
		comp_sizes(nxt, cur);
		sz[cur] += sz[nxt];
	}
}

// Returns a centroid of the active component containing root. It walks from
// root toward any active neighbour whose subtree holds more than half of the
// component, and stops where no such neighbour exists.
static int find_centroid(int root) {
	comp_sizes(root, -1);
	const int tot_sz = sz[root];
	int cur = root;
	int p = -1;
	bool moved = true;
	while (moved) {
		moved = false;
		for (int nxt : edges[cur]) {
			if (nxt == p || vis[nxt]) continue;
			if (sz[nxt] > tot_sz / 2) {
				p = cur;
				cur = nxt;
				moved = true;
				break;
			}
		}
	}
	return cur;
}

// One centroid-decomposition step on the active component containing cur.
//
// Every path through the component's centroid c is evaluated here, updating
// ans. Then c is marked visited and each remaining piece is processed
// recursively. The DP must be rooted at c itself, not at the entry node cur:
// c is the node that gets removed, so any path through c that is not evaluated
// at this step is never evaluated at all.
void decomp(int cur) {
	const int c = find_centroid(cur);
	get_sz(c);  // dp rows for every other node of the component, rooted at c

	// Combine the "towards" paths collected so far in dp[1][c] with an "away"
	// path into child, then fold child's "towards" paths into dp[1][c].
	// Choosing c after coming up from child adds adj_sum[c] - val[child].
	auto absorb = [&](int child) {
		for (int i = 0; i <= v; i++) {
			ans = std::max(ans, dp[1][c][i] + dp[0][child][v - i]);
			if (i) {
				dp[1][c][i] = std::max(dp[1][c][i],
				                       std::max(dp[1][child][i],
				                                dp[1][child][i - 1] + adj_sum[c] - val[child]));
			}
		}
	};

	// A path may also start at c with c chosen, which adds adj_sum[c].
	std::fill(dp[1][c] + 1, dp[1][c] + v + 1, adj_sum[c]);
	for (int child : edges[c]) {
		if (!vis[child]) absorb(child);
	}
	// Second sweep in reverse order, so the start-side subtree can come after
	// the end-side subtree as well.
	std::fill(dp[1][c] + 1, dp[1][c] + v + 1, adj_sum[c]);
	for (int j = static_cast<int>(edges[c].size()) - 1; j >= 0; j--) {
		const int child = edges[c][j];
		if (!vis[child]) absorb(child);
	}
	// Paths that end at c.
	ans = std::max(ans, dp[1][c][v]);

	vis[c] = true;
	for (int nxt : edges[c]) {
		if (!vis[nxt]) decomp(nxt);
	}
}

int main() {
	ios_base::sync_with_stdio(false); cin.tie(NULL);
	cin >> n >> v;
	for (int i = 0; i < n; i++) cin >> val[i];
	for (int i = 0; i < n-1; i++) {
		int a, b; cin >> a >> b;
		a--; b--;
		edges[a].push_back(b);
		edges[b].push_back(a);
		adj_sum[a] += val[b];
		adj_sum[b] += val[a];
	}
	decomp(0);
	cout << ans << "\n";
}
# | 결과 (result) | 실행 시간 (time) | 메모리 (memory) | Grader output
1 Correct 2 ms 2644 KB Output is correct
2 Incorrect 2 ms 2644 KB Output isn't correct
3 Halted 0 ms 0 KB -
# | 결과 (result) | 실행 시간 (time) | 메모리 (memory) | Grader output
1 Correct 2 ms 2644 KB Output is correct
2 Incorrect 2 ms 2644 KB Output isn't correct
3 Halted 0 ms 0 KB -
# | 결과 (result) | 실행 시간 (time) | 메모리 (memory) | Grader output
1 Correct 912 ms 170656 KB Output is correct
2 Correct 928 ms 172356 KB Output is correct
3 Correct 141 ms 168056 KB Output is correct
4 Correct 181 ms 167756 KB Output is correct
5 Correct 723 ms 167764 KB Output is correct
6 Correct 746 ms 167756 KB Output is correct
7 Correct 695 ms 167764 KB Output is correct
# | 결과 (result) | 실행 시간 (time) | 메모리 (memory) | Grader output
1 Correct 2 ms 2644 KB Output is correct
2 Incorrect 2 ms 2644 KB Output isn't correct
3 Halted 0 ms 0 KB -