#include <bits/stdc++.h>
using namespace std;
int n;  // number of nodes in the tree
int k;  // number of distinct nodes the walk must visit (start node X included)
int X;  // start node; converted to 0-based in main()
// dp1[1][v][t]: cheapest walk from v over t nodes of v's subtree that returns to v;
// dp1[0][v][t]: same, but the walk may stop anywhere. Filled by solve2().
std::vector<std::vector<int>> dp1[2];
std::vector<int> num;                               // num[v]: subtree size of v (filled by init)
std::vector<std::vector<std::pair<int, int>>> adj;  // adj[v]: list of (neighbor, edge weight)
void solve2(int x, int p) {
vector<vector<int>> dp2[2];
for (auto v: adj[x]) {
if (v.first == p) continue;
solve2(v.first, x);
}
dp1[0][x].push_back(0);
dp1[1][x].push_back(0);
dp2[0].push_back({});
dp2[1].push_back({});
dp2[0][0].push_back(0);
dp2[1][0].push_back(0);
int s = 0;
int cnt = 0;
for (auto v: adj[x]) {
if (v.first == p) continue;
dp2[0].push_back({});
dp2[1].push_back({});
for (int i = 0; i <= min(s + num[v.first], k - 1); i++) {
for (int j = max(0, i - num[v.first]); j <= min({s, k - 1, i}); j++) {
if (j == max(0, i - num[v.first])) {
if (i != j) {
dp2[0].back().push_back(min(dp2[0][cnt][j] + dp1[1][v.first][i - j] + 2 * v.second, dp2[1][cnt][j] + dp1[0][v.first][i - j]) + v.second);
dp2[1].back().push_back(dp2[1][cnt][j] + dp1[1][v.first][i - j] + 2 * v.second);
} else {
dp2[0].back().push_back(dp2[0][cnt][j]);
dp2[1].back().push_back(dp2[1][cnt][j]);
}
} else {
if (i != j) {
dp2[0][cnt + 1][i] = min(dp2[0][cnt + 1][i], min(dp2[0][cnt][j] + dp1[1][v.first][i - j] + 2 * v.second, dp2[1][cnt][j] + dp1[0][v.first][i - j] + v.second));
dp2[1][cnt + 1][i] = min(dp2[1][cnt + 1][i], dp2[1][cnt][j] + dp1[1][v.first][i - j] + 2 * v.second);
} else {
dp2[0][cnt + 1][i] = min(dp2[0][cnt + 1][i], dp2[0][cnt][j]);
dp2[1][cnt + 1][i] = min(dp2[1][cnt + 1][i], dp2[1][cnt][j]);
}
}
}
}
s += num[v.first];
cnt++;
}
for (int i = 0; i < (int)dp2[0][cnt].size(); i++) {
dp1[0][x].push_back(dp2[0][cnt][i]);
}
for (int i = 0; i < (int)dp2[1][cnt].size(); i++) {
dp1[1][x].push_back(dp2[1][cnt][i]);
}
}
int solve() {
vector<vector<int>> dp2[2];
for (auto v: adj[X]) {
solve2(v.first, X);
}
dp2[0].push_back({});
dp2[1].push_back({});
dp2[0][0].push_back(0);
dp2[1][0].push_back(0);
int s = 0;
int cnt = 0;
for (auto v: adj[X]) {
dp2[0].push_back({});
dp2[1].push_back({});
for (int i = 0; i <= min(s + num[v.first], k - 1); i++) {
for (int j = max(0, i - num[v.first]); j <= min({s, k - 1, i}); j++) {
if (j == max(0, i - num[v.first])) {
if (i != j) {
dp2[0].back().push_back(min(dp2[0][cnt][j] + dp1[1][v.first][i - j] + 2 * v.second, dp2[1][cnt][j] + dp1[0][v.first][i - j]) + v.second);
dp2[1].back().push_back(dp2[1][cnt][j] + dp1[1][v.first][i - j] + 2 * v.second);
} else {
dp2[0].back().push_back(dp2[0][cnt][j]);
dp2[1].back().push_back(dp2[1][cnt][j]);
}
} else {
if (i != j) {
dp2[0][cnt + 1][i] = min(dp2[0][cnt + 1][i], min(dp2[0][cnt][j] + dp1[1][v.first][i - j] + 2 * v.second, dp2[1][cnt][j] + dp1[0][v.first][i - j] + v.second));
dp2[1][cnt + 1][i] = min(dp2[1][cnt + 1][i], dp2[1][cnt][j] + dp1[1][v.first][i - j] + 2 * v.second);
} else {
dp2[0][cnt + 1][i] = min(dp2[0][cnt + 1][i], dp2[0][cnt][j]);
dp2[1][cnt + 1][i] = min(dp2[1][cnt + 1][i], dp2[1][cnt][j]);
}
}
}
}
s += num[v.first];
cnt++;
}
return dp2[0][cnt][k - 1];
}
// Computes num[x] = number of nodes in x's subtree, where p is x's parent
// (-1 for the root of the traversal). Recurses into every neighbor except p.
void init(int x, int p = -1) {
    num[x] = 1;  // x itself
    for (size_t idx = 0; idx < adj[x].size(); ++idx) {
        const int child = adj[x][idx].first;
        if (child != p) {
            init(child, x);
            num[x] += num[child];
        }
    }
}
// Reads "n k X" followed by n - 1 weighted edges "a b c" (1-based endpoints),
// builds the adjacency lists, and prints the answer computed by solve().
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> k >> X;
    --X;  // node ids in the input are 1-based; arrays below are 0-based
    adj.resize(n);
    num.resize(n);
    for (auto& table : dp1) table.resize(n);

    for (int edge = 1; edge < n; ++edge) {
        int from, to, len;
        cin >> from >> to >> len;
        --from;
        --to;
        adj[from].emplace_back(to, len);
        adj[to].emplace_back(from, len);
    }

    init(X);
    cout << solve() << "\n";
}
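/*
| # | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory) | Grader output        |
|---|---------------|------------------|-----------------|----------------------|
| 1 | Incorrect     | 1 ms             | 348 KB          | Output isn't correct |
| 2 | Halted        | 0 ms             | 0 KB            | -                    |
*/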
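/*
| # | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory) | Grader output        |
|---|---------------|------------------|-----------------|----------------------|
| 1 | Incorrect     | 15 ms            | 2908 KB         | Output isn't correct |
| 2 | Halted        | 0 ms             | 0 KB            | -                    |
*/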
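/*
| # | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory) | Grader output        |
|---|---------------|------------------|-----------------|----------------------|
| 1 | Incorrect     | 15 ms            | 2908 KB         | Output isn't correct |
| 2 | Halted        | 0 ms             | 0 KB            | -                    |
*/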
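/*
| # | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory) | Grader output        |
|---|---------------|------------------|-----------------|----------------------|
| 1 | Incorrect     | 1 ms             | 348 KB          | Output isn't correct |
| 2 | Halted        | 0 ms             | 0 KB            | -                    |
*/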