#include "closing.h"
#include <bits/stdc++.h>
// Capacity of every node-indexed array below; node ids must be < maxn.
constexpr int maxn = 200005;
// Short member-access aliases used throughout the file (p.fi / p.se).
#define fi first
#define se second
using namespace std;
// (neighbor, edge weight) entry of an adjacency list.
using ii = pair<int, int>;

// Current query, published by max_score() for the helpers that follow.
int n, x, y;
int64_t k;

// disX[v] / disY[v]: weighted distance from site x / site y to node v.
int64_t disX[maxn], disY[maxn];
// Filled by sub2(): min(disX, disY) and the extra amount max - min, per node.
int64_t cost_level_one[maxn], cost_level_two[maxn];
// Undirected weighted adjacency lists indexed by node id.
vector<ii> adj[maxn];
// Fills disX[v] with the weighted distance from site x to every node reachable
// from x through adj. It assumes adj is a tree: only the edge back to the
// parent is skipped, so a cycle would be walked forever.
// The traversal uses an explicit stack instead of recursion. The recursive
// std::function version needed one (heavy) call frame per tree level, so a
// path-shaped tree of ~maxn nodes could exhaust the call stack.
void calcdisX() {
    disX[x] = 0;
    vector<pair<int, int>> pending;  // (node, parent) still to expand
    pending.emplace_back(x, -1);
    while (!pending.empty()) {
        const auto [u, dad] = pending.back();  // copy before pop_back
        pending.pop_back();
        for (const auto& [v, l] : adj[u]) {
            if (v == dad) continue;
            disX[v] = disX[u] + l;
            pending.emplace_back(v, u);
        }
    }
}
// Fills disY[v] with the weighted distance from site y to every node reachable
// from y through adj. It assumes adj is a tree: only the edge back to the
// parent is skipped, so a cycle would be walked forever.
// The traversal uses an explicit stack instead of recursion. The recursive
// std::function version needed one (heavy) call frame per tree level, so a
// path-shaped tree of ~maxn nodes could exhaust the call stack.
void calcdisY() {
    disY[y] = 0;
    vector<pair<int, int>> pending;  // (node, parent) still to expand
    pending.emplace_back(y, -1);
    while (!pending.empty()) {
        const auto [u, dad] = pending.back();  // copy before pop_back
        pending.pop_back();
        for (const auto& [v, l] : adj[u]) {
            if (v == dad) continue;
            disY[v] = disY[u] + l;
            pending.emplace_back(v, u);
        }
    }
}
// Scores the strategy in which each node is counted at most once, at cost
// min(disX[v], disY[v]). Presumably this is the case where the sets reached
// from x and from y do not overlap.
// Nodes are bought cheapest-first. Returns how many fit within budget k
// (n if all of them do).
int sub1() {
    vector<int64_t> cheapest;
    for (int v = 0; v < n; ++v) cheapest.push_back(min(disX[v], disY[v]));
    sort(cheapest.begin(), cheapest.end());

    int64_t spent = 0;
    int taken = 0;
    while (taken < n) {
        spent += cheapest[taken];
        if (spent > k) break;  // this node no longer fits; stop counting
        ++taken;
    }
    return taken;
}
int sub2() {
int ans = 0;
for (int i = 0; i < n; i++) {
cost_level_one[i] = min(disX[i], disY[i]);
cost_level_two[i] = max(disX[i], disY[i]) - cost_level_one[i];
}
int64_t sum = 0;
multiset<ii> q1, q2, q3;
for (int i = 0; i < n; i++)
if (disX[i] + disY[i] == disX[y]) {
++ans;
sum += cost_level_one[i];
q1.insert({cost_level_two[i], i});
} else {
if (cost_level_one[i] > cost_level_two[i]) {
q2.insert({cost_level_one[i] + cost_level_two[i], i});
q3.insert({cost_level_one[i], i});
} else {
q1.insert({cost_level_one[i], i});
q1.insert({cost_level_two[i], i});
}
}
if (sum > k) return 0;
while (1) {
if (q1.empty() and q2.empty()) break;
if (sum + (q1.empty() ? LLONG_MAX/2 : q1.begin()->fi) > k and
sum + (q2.empty() ? LLONG_MAX/2 : q2.begin()->fi) > k) {
if (sum + (q3.empty() ? LLONG_MAX/2 : q3.begin()->fi) > k) break;
sum += q3.begin()->fi; break;
}
if (q1.size() >= 2 && !q2.empty()) {
int s1 = q1.begin()->fi + next(q1.begin())->fi;
if (s1 < q2.begin()->fi) {
if (sum + q1.begin()->fi > k) {
if (sum + q3.begin()->fi <= k) {
sum += q3.begin()->fi;
++ans;
}
break;
}
sum += q1.begin()->fi;
++ans;
q1.erase(q1.begin());
continue;
}
if (sum + q2.begin()->fi > k) {
if (sum + q1.begin()->fi + q3.begin()->fi <= k) {
ans += 2;
break;
} else if (sum + min(q1.begin()->fi, q3.begin()->fi) <= k) {
++ans;
break;
}
break;
}
sum += q2.begin()->fi; ans += 2;
q3.erase(ii{cost_level_one[q2.begin()->se], q2.begin()->se});
q2.erase(q2.begin());
continue;
}
if (q2.empty()) {
if (sum + q1.begin()->fi > k) break;
sum += q1.begin()->fi; ++ans;
q1.erase(q1.begin());
continue;
}
if (sum + q2.begin()->fi > k) {
if (q1.empty()) {
if (sum + q3.begin()->fi <= k) ++ans;
break;
}
if (sum + q1.begin()->fi + q3.begin()->fi <= k) {
ans += 2;
break;
} else if (sum + min(q1.begin()->fi, q3.begin()->fi) <= k) {
++ans;
break;
}
break;
}
sum += q2.begin()->fi; ans += 2;
q3.erase(ii{cost_level_one[q2.begin()->se], q2.begin()->se});
q2.erase(q2.begin());
}
return ans;
}
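/*
Local sample test. The format appears to be: number of cases, then "N X Y K",
then N-1 edges "U V W". The final line is presumably the expected answer.
1
5 2 3 91
1 0 25
2 1 16
3 2 71
4 3 62
//34 + 49
4
*/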
// Entry point, presumably declared in closing.h.
// Stores the query in the file-level globals and builds the undirected
// weighted tree from edges (U[e], V[e], W[e]) for e < N - 1. It computes the
// distances from both sites and returns the better of the two strategies
// sub1() and sub2(). The adjacency lists of nodes 0..N-1 are cleared again
// before returning, so the next call starts from empty lists.
int max_score(int N, int X, int Y, long long K, vector<int> U, vector<int> V, vector<int> W) {
    n = N;
    x = X;
    y = Y;
    k = K;

    for (int e = 0; e + 1 < N; ++e) {
        const int a = U[e];
        const int b = V[e];
        const int w = W[e];
        adj[a].emplace_back(b, w);
        adj[b].emplace_back(a, w);
    }

    calcdisX();
    calcdisY();
    const int separate = sub1();
    const int through_path = sub2();
    const int best = max(separate, through_path);

    for (int v = 0; v < N; ++v) adj[v].clear();
    return best;
}
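/*
Another local sample test, in the same apparent format as the one above.
No expected answer is recorded for it.
1
5 1 2 92
1 0 75
2 1 2
3 1 88
4 3 54
*/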
/*
The table below appears to be residue pasted from a judge's submission page.
It is kept inside a comment so that the source file still compiles.
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/