#include "closing.h"
#include <bits/stdc++.h>
// Capacity of every per-vertex array below (vertex ids must be < maxn).
#define maxn 200005
// Shorthands for std::pair members.
#define fi first
#define se second
using namespace std;
// Adjacency entry: (neighbour, edge weight); both come from int inputs.
using ii = pair<int, int>;
// Problem state copied in by max_score: vertex count n, the two roots x and y,
// and the cost budget k that summed distances are compared against.
int n, x, y; int64_t k;
// Weighted distance from x (resp. y) to every vertex; filled by calcdisX / calcdisY.
int64_t disX[maxn], disY[maxn];
// Per-vertex min(disX, disY) and (max - min) of the two distances; written by sub2.
int64_t cost_level_one[maxn], cost_level_two[maxn];
// Undirected adjacency lists, built by max_score and cleared before it returns.
vector<ii> adj[maxn];
void calcdisX() {
function<void(int, int)> dfs = [&](int u, int dad) {
for (auto [v, l] : adj[u])
if (v != dad) {
disX[v] = disX[u] + l;
dfs(v, u);
}
};
disX[x] = 0;
dfs(x, -1);
}
void calcdisY() {
function<void(int, int)> dfs = [&](int u, int dad) {
for (auto [v, l] : adj[u])
if (v != dad) {
disY[v] = disY[u] + l;
dfs(v, u);
}
};
disY[y] = 0;
dfs(y, -1);
}
// Score when every vertex is paid for at most once: each vertex costs the
// smaller of its two distances, and the cheapest ones are bought until the
// running total would exceed k. Returns how many vertices were bought.
int sub1() {
    vector<int64_t> cheapest(n);
    for (int v = 0; v < n; ++v)
        cheapest[v] = min(disX[v], disY[v]);
    sort(cheapest.begin(), cheapest.end());

    int taken = 0;
    int64_t spent = 0;
    while (taken < n) {
        spent += cheapest[taken];
        if (spent > k) break;   // this vertex is no longer affordable
        ++taken;
    }
    return taken;
}
// Best score when the two reachable sets may overlap.
//
// Every vertex on the x-y path (disX + disY == disX[y]) is forced to level 1
// and pays min(disX, disY). The remaining budget is spent on optional items:
//   * level-2 upgrade of a path vertex: 1 point for (max - min);
//   * off-path vertex with min <= max - min: two independent 1-point items,
//     min and (max - min). A sorted prefix never takes the second without the
//     first, apart from equal-cost ties, which are interchangeable;
//   * off-path vertex with min > max - min ("concave"): 2 points for max,
//     or 1 point for min. Two concave vertices at level 1 can be replaced by
//     the cheaper one at level 2 (same points, lower cost), so at most one
//     concave vertex is used at level 1 only.
//     NOTE(review): this exchange argument ignores tree-order constraints,
//     as in the usual analysis of this task.
// All costs stay int64_t: distances can exceed the range of int, which the
// previous pair<int,int> priority queue silently truncated.
// Returns 0 if the forced path cost alone exceeds k.
int sub2() {
    const int64_t path_len = disX[y];
    int64_t budget = k;
    int forced = 0;                           // path vertices held at level 1
    vector<int64_t> singles;                  // independent 1-point items
    vector<pair<int64_t, int64_t>> concave;   // (cost for 2 points, cost for 1 point)
    for (int i = 0; i < n; i++) {
        const int64_t lo = min(disX[i], disY[i]);
        const int64_t extra = max(disX[i], disY[i]) - lo;
        cost_level_one[i] = lo;
        cost_level_two[i] = extra;
        if (disX[i] + disY[i] == path_len) {
            budget -= lo;
            ++forced;
            singles.push_back(extra);
        } else if (lo <= extra) {
            singles.push_back(lo);
            singles.push_back(extra);
        } else {
            concave.emplace_back(lo + extra, lo);
        }
    }
    if (budget < 0) return 0;

    sort(singles.begin(), singles.end());
    vector<int64_t> single_pref(singles.size() + 1, 0);
    for (size_t i = 0; i < singles.size(); i++)
        single_pref[i + 1] = single_pref[i] + singles[i];
    // Number of cheapest singles affordable with `money` (money >= 0).
    auto singles_within = [&](int64_t money) -> int {
        return int(upper_bound(single_pref.begin(), single_pref.end(), money) - single_pref.begin()) - 1;
    };

    sort(concave.begin(), concave.end());   // ascending by cost for 2 points
    const int m = int(concave.size());
    vector<int64_t> pair_pref(m + 1, 0);
    for (int j = 0; j < m; j++) pair_pref[j + 1] = pair_pref[j] + concave[j].first;
    vector<int64_t> suffix_min_one(m + 1, INT64_MAX);
    for (int j = m - 1; j >= 0; j--)
        suffix_min_one[j] = min(suffix_min_one[j + 1], concave[j].second);

    int best = 0;
    int64_t best_saving = INT64_MIN;   // max (two - one) over concave[0..j]
    for (int j = 0; j <= m; j++) {
        // Every option below costs at least pair_pref[j], so larger j cannot help.
        if (pair_pref[j] > budget) break;
        best = max(best, 2 * j + singles_within(budget - pair_pref[j]));
        if (j == m) break;
        // j concave vertices at level 2 plus one more at level 1. Either the
        // extra one comes after the j cheapest, or it is one of the j+1
        // cheapest and gets demoted to level 1.
        best_saving = max(best_saving, concave[j].first - concave[j].second);
        const int64_t odd_cost = min(pair_pref[j] + suffix_min_one[j], pair_pref[j + 1] - best_saving);
        if (odd_cost <= budget)
            best = max(best, 2 * j + 1 + singles_within(budget - odd_cost));
    }
    return forced + best;
}
// Grader entry point: stores the instance in the file-level globals, builds
// the undirected adjacency lists from the N-1 edges (U[e], V[e], W[e]), and
// returns the better of the non-overlapping (sub1) and overlapping (sub2)
// strategies. The adjacency lists are cleared before returning so the
// function can be called again.
int max_score(int N, int X, int Y, long long K, vector<int> U, vector<int> V, vector<int> W) {
    n = N;
    x = X;
    y = Y;
    k = K;
    for (int e = 0; e + 1 < N; ++e) {
        const int a = U[e], b = V[e], w = W[e];
        adj[a].emplace_back(b, w);
        adj[b].emplace_back(a, w);
    }
    calcdisX();
    calcdisY();
    const int disjoint_best = sub1();
    const int overlap_best = sub2();
    const int best = max(disjoint_best, overlap_best);
    for (int v = 0; v < N; ++v) adj[v].clear();
    return best;
}
/*
2
7 0 2 10
0 1 2
0 3 3
1 2 4
2 4 2
2 5 5
5 6 3
4 0 3 20
0 1 18
1 2 1
2 3 19
*/
/*
1
4 0 3 20
0 1 18
1 2 1
2 3 19
*/
// Expected output for the 2-test input in the first comment above: 6 3
/*
Pasted judge result table (submission-page residue, not part of the program):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/