#include <bits/stdc++.h>
#include "closing.h"
// #include "grader.cpp"
using namespace std;
typedef long long ll;

// Upper bound on the number of vertices (plus slack for 1-based indexing).
const int N = 2e5 + 10;

// Per-call state shared between dfs() and max_score().
// Vertices are 1-based; x[0] and x[1] are the two source vertices.
int n, x[2];
// k holds the budget of the current call; dist[v][id] is the weighted
// distance from source x[id] to vertex v.
ll k, dist[N][2];
// Adjacency list: g[v] holds (edge weight, neighbour) pairs.
vector<pair<int, int>> g[N];

// Fills dist[*][id] for every vertex reachable from v without stepping back
// to p. The caller must set dist[v][id] beforehand.
//
// Implemented with an explicit stack instead of recursion: on a path-shaped
// tree the recursion depth would equal the number of vertices.
void dfs(int v, int id, int p = -1){
    vector<pair<int, int>> pending;  // (vertex, vertex we arrived from)
    pending.push_back({v, p});
    while (!pending.empty()){
        auto [cur, from] = pending.back();
        pending.pop_back();
        for (auto [w, u] : g[cur]){
            if (u == from) continue;
            dist[u][id] = dist[cur][id] + w;
            pending.push_back({u, cur});
        }
    }
}

// Greedy score for the configuration in which vertex `mid` is the vertex on
// the x[0]..x[1] label range that is paid for from both sources.
//
// Vertices x[0]..mid-1 are pre-paid for source 0, vertices mid+1..x[1] for
// source 1, and `mid` for the larger of its two distances. Afterwards the
// cheapest remaining "reach v from source id" step is bought repeatedly while
// the budget allows. Once a vertex has been bought for one source, reaching it
// from the other source only costs the difference of the two distances.
//
// Returns 0 when even the mandatory part exceeds the budget.
//
// NOTE(review): indexing by label range assumes the tree is a path whose
// vertices are numbered consecutively along it and that x[0] <= x[1] --
// confirm against the intended subtask constraints.
static int overlap_score(int mid, ll budget){
    set<pair<ll, int>> cand[2];  // cand[id]: (cost, vertex) still buyable for source id
    vector<int> vis(n + 1, 0);   // bit 0: reached from x[0], bit 1: reached from x[1]

    for (int i = 1; i < x[0]; i ++)
        for (int id : {0, 1})
            cand[id].insert({dist[i][id], i});
    for (int i = x[0]; i < mid; i ++){
        vis[i] = 1;
        budget -= dist[i][0];
        cand[1].insert({dist[i][1] - dist[i][0], i});
    }
    budget -= max(dist[mid][0], dist[mid][1]);
    vis[mid] = 3;
    for (int i = mid + 1; i <= x[1]; i ++){
        vis[i] = 2;
        budget -= dist[i][1];
        cand[0].insert({dist[i][0] - dist[i][1], i});
    }
    for (int i = x[1] + 1; i <= n; i ++)
        for (int id : {0, 1})
            cand[id].insert({dist[i][id], i});

    if (budget < 0) return 0;

    // Every vertex in x[0]..x[1] counts once, and `mid` counts twice.
    int score = x[1] - x[0] + 2;

    // Each iteration removes exactly one element overall (the optional
    // erase+insert on the other side keeps that side's size), so the loop
    // terminates without an iteration guard.
    while (!cand[0].empty() or !cand[1].empty()){
        // Pick the side with the cheaper head; ties go to side 0.
        int side;
        if (cand[0].empty()) side = 1;
        else if (cand[1].empty()) side = 0;
        else side = (cand[0].begin()->first <= cand[1].begin()->first) ? 0 : 1;
        const int other = side ^ 1;

        const auto [cost, v] = *cand[side].begin();
        if (budget < cost) break;
        budget -= cost;
        score++;
        cand[side].erase(cand[side].begin());
        vis[v] |= 1 << side;

        if (vis[v] & (1 << other)) continue;
        // v is now paid up to dist[v][side]; reaching it from the other
        // source only needs the remaining difference.
        auto it = cand[other].find({dist[v][other], v});
        if (it == cand[other].end()) continue;
        cand[other].erase(it);
        cand[other].insert({dist[v][other] - dist[v][side], v});
    }
    return score;
}

// Computes the best score found for a tree with nn vertices, sources xx and
// yy (0-based), budget kk and edges (uu[i], vv[i]) of weight ww[i].
//
// Two strategies are evaluated and the larger result is returned:
//   1. Merge both sorted distance lists and buy the globally cheapest
//      "reach a vertex from one source" step while the budget lasts. The loop
//      stops as soon as one of the two lists is exhausted.
//   2. For every candidate vertex `mid` in the label range xx..yy, the
//      overlap greedy implemented in overlap_score().
//
// The adjacency lists are cleared before returning so the function can be
// called again.
int max_score(int nn, int xx, int yy, ll kk, vector<int> uu, vector<int> vv, vector<int> ww){
    // Switch to 1-based vertex labels.
    xx++, yy++;
    for (int i = 0; i < nn - 1; i ++)
        uu[i]++, vv[i]++;
    n = nn, x[0] = xx, x[1] = yy, k = kk;
    for (int i = 0; i < n - 1; i ++){
        g[uu[i]].push_back({ww[i], vv[i]});
        g[vv[i]].push_back({ww[i], uu[i]});
    }

    // Distances from both sources, each also kept as a sorted list.
    vector<ll> sorted_dist[2];
    for (int id : {0, 1}){
        dist[x[id]][id] = 0;
        dfs(x[id], id);
        sorted_dist[id].reserve(n);
        for (int i = 1; i <= n; i ++)
            sorted_dist[id].push_back(dist[i][id]);
        sort(sorted_dist[id].begin(), sorted_dist[id].end());
    }

    // Strategy 1: two-pointer merge of the sorted lists; ties go to source 0.
    int ans = 0;
    ll budget = kk;
    int taken[2] = {0, 0};
    while (taken[0] < n and taken[1] < n){
        const ll a = sorted_dist[0][taken[0]];
        const ll b = sorted_dist[1][taken[1]];
        const int side = (a <= b) ? 0 : 1;
        const ll cost = side == 0 ? a : b;
        if (budget < cost) break;
        budget -= cost;
        taken[side]++;
        ans++;
    }

    // Strategy 2: try every meeting vertex between the two sources.
    for (int mid = x[0]; mid <= x[1]; mid ++)
        ans = max(ans, overlap_score(mid, kk));

    for (int i = 0; i <= n; i ++) g[i].clear();
    return ans;
}
Compilation message (stderr)
closing.cpp: In function 'int max_score(int, int, int, ll, std::vector<int>, std::vector<int>, std::vector<int>)':
closing.cpp:90:36: warning: division by zero [-Wdiv-by-zero]
90 | if (ite > 1e6) return 1/0;
| ~^~
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |