#include "closing.h"
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <iterator>
#include <set>
#include <utility>
#include <vector>
using namespace std;
using ll = long long;
using Graph = vector<vector<pair<int,ll>>>;
void dfs(int u, int fu, const Graph& g, vector<ll>& dists, vector<int>& trace) {
for (auto [v, w] : g[u]) {
if (v == fu) continue;
assert(dists[v] < 0);
dists[v] = dists[u] + w;
trace[v] = u;
dfs(v, u, g, dists, trace);
}
}
// Single-source distances from `u` over the weighted tree `g`.
// Returns {dists, trace}:
//   dists[v] - weighted distance from u to v (dists[u] == 0),
//   trace[v] - v's neighbour on the path back towards u (trace[u] == -1).
// Any vertex that dfs does not reach keeps -1 in both vectors.
pair<vector<ll>, vector<int>> get_dists(int u, const Graph& g) {
    const int vertex_count = static_cast<int>(g.size());
    pair<vector<ll>, vector<int>> result(vector<ll>(vertex_count, -1),
                                         vector<int>(vertex_count, -1));
    auto& [dist_from_u, parent] = result;
    dist_from_u[u] = 0;
    dfs(u, -1, g, dist_from_u, parent);
    return result;
}
// Baseline score used by sub9 before it considers overlapping ranges.
// Pools every entry of dx and dy as an independent one-point item and
// buys the cheapest items first while their running total stays <= k.
// Returns how many items were bought.
int no_common(ll k, const vector<ll>& dx, const vector<ll>& dy) {
    vector<ll> prices(dx);
    prices.insert(prices.end(), dy.begin(), dy.end());
    sort(prices.begin(), prices.end());
    size_t idx = 0;
    while (idx < prices.size() && prices[idx] <= k) {
        k -= prices[idx];
        ++idx;
    }
    return static_cast<int>(idx);
}
// linear graph
struct Event {
ll cost;
int i;
int typ; // 1 = SINGLE, 2 = PAIR
};
bool operator < (const Event& a, const Event& b) {
return a.cost * (3 - a.typ) < b.cost * (3 - b.typ);
}
int sub9(int n, ll k, int x, int y,
const vector<ll>& dx, const vector<ll>& dy,
const vector<bool>& on_path) {
int res = no_common(k, dx, dy);
// 0 .. [xl .. x .. xr]
// [yl .. y .. yr] .. n-1
// Now we only consider the case where [xl, xr] and [yl, yr] overlap
// They must overlap at least one point between [x, y]
int cur = 0;
vector<int> taken(n, 0);
vector<Event> events;
// 1. For every vertex not in path x -> y
for (int i = 0; i < n; ++i) {
if (on_path[i]) continue;
int a = min(dx[i], dy[i]);
int b = max(dx[i], dy[i]);
if (b - a >= a) {
events.push_back({a, i, 1});
events.push_back({b - a, i, 1});
} else {
events.push_back({b, i, 2});
}
}
// 2. On path x -> y, every point must be visited by either x or y
for (int i = 0; i < n; ++i) {
if (!on_path[i]) continue;
k -= min(dx[i], dy[i]);
// Additionally, we can choose to visit from other point
++cur;
events.push_back({llabs(dx[i] - dy[i]), i, 1});
taken[i] = 1;
}
if (k >= 0) {
sort(events.begin(), events.end());
for (auto& event : events) {
if (event.cost > k) break;
k -= event.cost;
cur += event.typ;
taken[event.i] += event.typ;
assert(taken[event.i] <= 2);
event.typ = 0;
}
// Break "bundle"
set<pair<ll, int>> costs;
for (auto& event : events) {
if (event.typ == 2) costs.insert({min(dx[event.i], dy[event.i]), event.i});
}
for (auto [cost, i] : costs) if (k >= cost) {
k -= cost, ++cur;
taken[i] = 1;
}
multiset<ll> ones, twos;
for (int i = 0; i < n; ++i) {
if (taken[i] == 1 && !on_path[i]) {
ones.insert(min(dx[i], dy[i]));
twos.insert(max(dx[i], dy[i]));
} else if (taken[i] == 0) {
twos.insert(max(dx[i], dy[i]));
}
}
while (ones.size() > 0 && twos.size() > 0) {
auto it_one = std::prev(ones.end());
auto it_two = twos.begin();
auto cost = *it_two - *it_one;
if (cost <= k) {
k -= cost;
++cur;
ones.erase(it_one);
twos.erase(it_two);
} else break;
}
res = max(res, cur);
}
return res;
}
int max_score(int n, int X, int Y, long long K,
vector<int> U, vector<int> V, vector<int> W) {
Graph g(n);
assert(int(U.size()) == n-1);
assert(int(V.size()) == n-1);
assert(int(W.size()) == n-1);
for (int i = 0; i < n - 1; i++) {
int u = U[i];
int v = V[i];
int w = W[i];
g[u].emplace_back(v, w);
g[v].emplace_back(u, w);
}
if (X > Y) swap(X, Y);
auto [distsX, traceX] = get_dists(X, g);
auto [distsY, traceY] = get_dists(Y, g);
vector<bool> on_path(n, false);
int cur = Y;
while (cur != X) {
on_path[cur] = true;
cur = traceX[cur];
}
on_path[X] = true;
return sub9(n, K, X, Y, distsX, distsY, on_path);
}
Compilation message
closing.cpp: In function 'void dfs(int, int, const Graph&, std::vector<long long int>&, std::vector<int>&)':
closing.cpp:13:9: error: 'assert' was not declared in this scope
13 | assert(dists[v] < 0);
| ^~~~~~
closing.cpp:4:1: note: 'assert' is defined in header '<cassert>'; did you forget to '#include <cassert>'?
3 | #include <set>
+++ |+#include <cassert>
4 | #include <vector>
closing.cpp: In function 'int sub9(int, ll, int, int, const std::vector<long long int>&, const std::vector<long long int>&, const std::vector<bool>&)':
closing.cpp:99:13: error: 'assert' was not declared in this scope
99 | assert(taken[event.i] <= 2);
| ^~~~~~
closing.cpp:99:13: note: 'assert' is defined in header '<cassert>'; did you forget to '#include <cassert>'?
closing.cpp: In function 'int max_score(int, int, int, long long int, std::vector<int>, std::vector<int>, std::vector<int>)':
closing.cpp:140:5: error: 'assert' was not declared in this scope
140 | assert(int(U.size()) == n-1);
| ^~~~~~
closing.cpp:140:5: note: 'assert' is defined in header '<cassert>'; did you forget to '#include <cassert>'?