#include "bits/stdc++.h"
#include "closing.h"
#define rep(i, n) for(int i = 0; i < (n); ++i)
#define all(a) (a).begin(), (a).end()
#define rall(a) (a).rbegin(), (a).rend()
#define ar array
using namespace std;
typedef long long ll;
using pi = pair<int, int>;
using vi = vector<int>;
using vl = vector<ll>;
using vpi = vector<pi>;
const ll INF = 4e18;
// Debug assertion that hangs instead of aborting.
//
// When `c` is true the call returns immediately. When `c` is false the
// function never returns: it spins forever, so a violated invariant shows up
// as a time-out (presumably to tell it apart from a wrong answer or a crash
// on a judge -- confirm with the author).
//
// The counter is `volatile` on purpose. A non-terminating loop without
// observable side effects is undefined behaviour in C++ (forward-progress
// guarantee), so an optimiser would be allowed to delete the loop and let a
// failed assertion pass silently. Volatile accesses are observable, which
// keeps the hang well-defined.
void my_assert(bool c) {
    if (c) return;
    volatile int x = 0;
    while (true) {
        x = x + 1;
        // Reset well before the counter could overflow.
        if (x == 10) x = -5 * x;
    }
}
// Iterative (bottom-up) segment tree over sums.
//
// Leaves live at t[n .. 2n-1]; the parent of node i is i / 2 and stores the
// sum of its two children. Queries use half-open ranges [l, r).
struct SegTree {
    vector<ll> t;
    int n;

    // Picks n as the smallest power of two that is >= 2 and strictly greater
    // than _n, then sizes the storage to 2 * n nodes.
    void build(int _n) {
        for (n = 2; n <= _n; n <<= 1) {
        }
        t.resize(2 * n);
    }

    // Adds x to leaf i and recomputes every ancestor on the way to the root.
    void upd(int i, ll x) {
        i += n;
        t[i] += x;
        while (i > 1) {
            const int parent = i >> 1;
            t[parent] = t[i] + t[i ^ 1];
            i = parent;
        }
    }

    // Returns the sum of leaves with index in [l, r).
    ll get(int l, int r) {
        ll total = 0;
        int lo = l + n;
        int hi = r + n;
        while (lo < hi) {
            // An odd left bound is a right child: take it and step past it.
            if (lo & 1) total += t[lo++];
            // An odd right bound means the node just left of it is included.
            if (hi & 1) total += t[--hi];
            lo >>= 1;
            hi >>= 1;
        }
        return total;
    }
};
// Breadth-first traversal from `x` that accumulates edge weights.
//
// Parameters:
//   n - number of vertices (vertices are 0 .. n-1)
//   x - start vertex
//   g - adjacency list; g[v] holds (neighbour, weight) pairs
//
// Returns {c, p}: c[v] is the accumulated weight from x to v (INF if v was
// never reached) and p[v] is the vertex from which v was first reached
// (-1 for x itself and for unreached vertices).
//
// The assertion checks that a vertex's distance is only ever assigned once,
// which holds when the graph is a tree with positive weights.
//
// `g` is taken by const reference so the adjacency list is not deep-copied on
// every call, and the result vectors are moved into the returned pair.
pair<vl, vi> BuildC(int n, int x, const vector<vpi>& g) {
    vl c(n, INF);
    vi p(n, -1);
    c[x] = 0;
    queue<int> q;
    q.push(x);
    while (!q.empty()) {
        const int v = q.front();
        q.pop();
        for (const auto& [u, w] : g[v]) {
            if (c[u] > w + c[v]) {
                // A shorter route to an already-reached vertex would mean
                // the input is not a tree.
                my_assert(c[u] == INF);
                c[u] = w + c[v];
                p[u] = v;
                q.push(u);
            }
        }
    }
    return {std::move(c), std::move(p)};
}
int max_score(int n, int x, int y, ll k, vi U, vi V, vi W) {
k *= 2;
int ans = 2;
vector<vpi> g(n);
rep(i, U.size()) {
int u = U[i];
int v = V[i];
int w = W[i];
w *= 2;
my_assert(w > 0);
g[u].emplace_back(v, w);
g[v].emplace_back(u, w);
}
auto [cx, px] = BuildC(n, x, g);
auto [cy, py] = BuildC(n, y, g);
{
auto sx = cx, sy = cy;
sort(all(sx));
sort(all(sy));
ll sum = accumulate(all(sy), 0ll);
int j = n;
for (int i = 0; i <= n; ++i) {
if (i) sum += sx[i - 1];
while (j > 0 && sum > k) {
j--;
sum -= sy[j];
}
if (sum > k) break;
ans = max(ans, i + j);
}
}
vector<int> mids;
{
vector<int> path;
for(int v = x; v != -1; v = py[v]) path.push_back(v);
rep(i, path.size()) {
int v = path[i];
if (i + 1 < path.size() && cy[path[i + 1]] > cx[path[i + 1]]) continue;
if (i > 0 && cx[path[i - 1]] > cy[path[i - 1]]) continue;
mids.push_back(v);
}
}
// my_assert(mids.size() <= 4);
for (auto mid: mids) {
// if (true) {
// my_assert(mid >= x && mid <= y);
// int l = x;
// int r = y;
// while (true) {
// ll s = 0;
// for (int i = l; i <= r; ++i) s += min(cx[i], cy[i]);
// s += abs(cx[mid] - cy[mid]);
// if (s > k) break;
// int score = r - l + 1 + 1;
// vector<ll> can;
// for (int i = l; i <= r; ++i) {
// if (i == mid) continue;
// can.push_back(abs(cx[i] - cy[i]));
// }
// ans = max(ans, score);
// sort(all(can));
// for (auto &v: can) {
// s += v;
// if (s > k) break;
// score++;
// ans = max(ans, score);
// }
// if (l == 0 && r == n - 1) break;
// if (l == 0) {
// r++;
// } else if (r == n - 1) {
// l--;
// } else if (min(cx[l - 1], cy[l - 1]) > min(cx[r + 1], cy[r + 1])) {
// r++;
// } else {
// l--;
// }
// }
// continue;
// }
vector<int> state(n);
ll use = max(cx[mid], cy[mid]);
state[mid] = 2;
for (int vx = px[mid]; vx != -1; vx = px[vx]) {
my_assert(state[vx] == 0);
my_assert(cx[vx] < cy[vx]);
state[vx] = 1;
use += cx[vx];
}
for (int vy = py[mid]; vy != -1; vy = py[vy]) {
my_assert(state[vy] == 0);
my_assert(cx[vy] > cy[vy]);
state[vy] = 1;
use += cy[vy];
}
if (use > k) continue;
my_assert(state[x] && state[y]);
int score_default = accumulate(all(state), 0);
ans = max(ans, score_default);
vector<ll> xx;
xx.push_back(-INF);
xx.push_back(INF);
vector<ll> diff(n), mx(n), mn(n);
rep(i, n) {
diff[i] = abs(cx[i] - cy[i]);
xx.push_back(diff[i]);
mn[i] = min(cx[i], cy[i]);
mx[i] = max(cx[i], cy[i]);
my_assert(mx[i] % 2 == 0);
xx.push_back(mx[i] / 2);
}
sort(all(xx));
xx.resize(unique(all(xx)) - xx.begin());
auto GetInd = [&](ll x) {
int i = lower_bound(all(xx), x) - xx.begin();
my_assert(i < xx.size() && xx[i] == x);
return i;
};
int sz = xx.size();
SegTree st1, st2, ct1, ct2;
st1.build(sz);
st2.build(sz);
ct1.build(sz);
ct2.build(sz);
vi ord2;
vi ord1;
rep(i, n) {
if (state[i] == 2) continue;
if (state[i] == 1) {
st1.upd(GetInd(diff[i]), diff[i]);
ct1.upd(GetInd(diff[i]), 1);
continue;
}
// used[i] = false;
ord2.push_back(i);
ord1.push_back(i);
// if (mn[i] > diff[i]) {
// st2.upd(GetInd(mx[i] / 2), mx[i]);
// ct2.upd(GetInd(mx[i] / 2), 1);
// }
}
// rep(v, n) if (state[v]) {
// for (auto &[u, w]: g[v]) {
// if (state[u] == 0 && !used[u]) {
// dfs(u);
// }
// }
// }
sort(all(ord1),
[&](const int &i, const int &j) { return make_pair(mn[i], diff[i]) < make_pair(mn[j], diff[j]); });
sort(all(ord2),
[&](const int &i, const int &j) { return diff[i] < diff[j]; });
auto FindNext = [&](int l) {
// for(int i = l; i < sz; ++i) {
// if (ct1.get(i, i + 1)) return i;
// }
// return sz;
int L = l - 1;
int R = sz;
while (R - L > 1) {
int mid = (L + R) / 2;
if (ct1.get(l, mid + 1)) {
R = mid;
} else {
L = mid;
}
}
my_assert(ct1.get(l, R) == 0);
if (R != sz) my_assert(ct1.get(R, R + 1) > 0);
return R;
};
auto Check = [&]() {
if (use > k) return;
// vector<ll> val2, val1;
// rep(i, n) {
// if (state[i] == 1) {
// val1.push_back(diff[i]);
// }
// if (state[i] == 0) {
// val2.push_back(mx[i]);
// }
// }
// int score = score_default;
// sort(all(val1));
// sort(all(val2));
// int i = 0;
// int j = 0;
// ll have = k - use;
// while (true) {
// ans = max(ans, score);
// ll to = INF;
// if (i < val1.size() && val1[i] <= have) to = min(to, val1[i]);
// if (j < val2.size() && val2[j] <= have) to = min(to, val2[j] / 2);
// if (j < val2.size() && val2[j] <= have && to == val2[j] / 2) {
// have -= val2[j];
// score += 2;
// j++;
// } else if (i < val1.size() && val1[i] <= have && to == val1[i]) {
// have -= val1[i];
// score++;
// i++;
// } else {
// my_assert(to == INF);
// break;
// }
// }
// return;
// int l = -1;
// int r = 0;
// while (r < sz && use + st1.get(0, r + 1) + st2.get(0, r + 1) <= k) {
// r++;
// l++;
// }
int l = -1;
int r = sz;
while (r - l > 1) {
int mid = (l + r) / 2;
ll s = st1.get(0, mid + 1) + st2.get(0, mid + 1);
if (s + use <= k) {
l = mid;
} else {
r = mid;
}
}
my_assert(l != -1);
int score = score_default + ct1.get(0, r) + ct2.get(0, r) * 2;
if (r == sz) {
// my_assert(score == 2 * n);
ans = max(ans, score);
return;
}
int cnt2 = ct2.get(r, r + 1);
int cnt1 = ct1.get(r, r + 1);
int nxt = FindNext(r + 1);
int score2 = score;
for(int t = 0; t <= 1; ++t) {
score = score2;
if (t && nxt == sz) continue;
ll have = k - use - st1.get(0, r) - st2.get(0, r);
if (t && xx[nxt] <= have) {
have -= xx[nxt];
score++;
}
my_assert(have >= 0);
if (cnt2) {
ll t2 = min(1ll * cnt2, have / (2 * xx[r]));
cnt2 -= t2;
have -= t2 * xx[r] * 2;
score += t2 * 2;
}
if (cnt1) {
ll t1 = min(1ll * cnt1, have / xx[r]);
cnt1 -= t1;
have -= t1 * xx[r];
score += t1;
}
my_assert(have >= 0);
ans = max(ans, score);
}
};
vector<bool> ok(n);
int uk2 = 0;
auto Check2 = [&] (ll value) {
while (uk2 < ord2.size() && diff[ord2[uk2]] <= value) {
int i = ord2[uk2++];
if (state[i] == 0) {
ok[i] = true;
st2.upd(GetInd(mx[i] / 2), mx[i]);
ct2.upd(GetInd(mx[i] / 2), 1);
}
}
Check();
};
for (auto &i: ord1) {
Check2(mn[i] - 1);
my_assert(state[i] == 0);
state[i] = 1;
use += mn[i];
if (use > k) break;
score_default++;
if (ok[i]) {
st2.upd(GetInd(mx[i] / 2), -mx[i]);
ct2.upd(GetInd(mx[i] / 2), -1);
}
st1.upd(GetInd(diff[i]), diff[i]);
ct1.upd(GetInd(diff[i]), 1);
Check();
}
Check2(INF);
}
my_assert(ans <= 2 * n);
return ans;
}
Compilation message
closing.cpp: In function 'int max_score(int, int, int, ll, vi, vi, vi)':
closing.cpp:4:36: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
4 | #define rep(i, n) for(int i = 0; i < (n); ++i)
| ^
closing.cpp:78:5: note: in expansion of macro 'rep'
78 | rep(i, U.size()) {
| ^~~
closing.cpp:4:36: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
4 | #define rep(i, n) for(int i = 0; i < (n); ++i)
| ^
closing.cpp:109:9: note: in expansion of macro 'rep'
109 | rep(i, path.size()) {
| ^~~
closing.cpp:111:23: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
111 | if (i + 1 < path.size() && cy[path[i + 1]] > cx[path[i + 1]]) continue;
| ~~~~~~^~~~~~~~~~~~~
closing.cpp: In lambda function:
closing.cpp:189:25: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
189 | my_assert(i < xx.size() && xx[i] == x);
| ~~^~~~~~~~~~~
closing.cpp: In lambda function:
closing.cpp:338:24: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
338 | while (uk2 < ord2.size() && diff[ord2[uk2]] <= value) {
| ~~~~^~~~~~~~~~~~~
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Execution timed out |
1068 ms |
348 KB |
Time limit exceeded |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
164 ms |
35664 KB |
Output is correct |
2 |
Correct |
203 ms |
34892 KB |
Output is correct |
3 |
Correct |
96 ms |
3096 KB |
Output is correct |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
0 ms |
344 KB |
Output is correct |
2 |
Correct |
0 ms |
348 KB |
Output is correct |
3 |
Correct |
0 ms |
600 KB |
Output is correct |
4 |
Correct |
1 ms |
348 KB |
Output is correct |
5 |
Execution timed out |
1076 ms |
348 KB |
Time limit exceeded |
6 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
0 ms |
344 KB |
Output is correct |
2 |
Correct |
0 ms |
348 KB |
Output is correct |
3 |
Correct |
0 ms |
600 KB |
Output is correct |
4 |
Correct |
1 ms |
348 KB |
Output is correct |
5 |
Execution timed out |
1076 ms |
348 KB |
Time limit exceeded |
6 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
0 ms |
344 KB |
Output is correct |
2 |
Correct |
0 ms |
348 KB |
Output is correct |
3 |
Correct |
0 ms |
600 KB |
Output is correct |
4 |
Correct |
1 ms |
348 KB |
Output is correct |
5 |
Execution timed out |
1076 ms |
348 KB |
Time limit exceeded |
6 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Execution timed out |
1068 ms |
348 KB |
Time limit exceeded |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Execution timed out |
1068 ms |
348 KB |
Time limit exceeded |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Execution timed out |
1068 ms |
348 KB |
Time limit exceeded |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Execution timed out |
1068 ms |
348 KB |
Time limit exceeded |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Execution timed out |
1068 ms |
348 KB |
Time limit exceeded |
2 |
Halted |
0 ms |
0 KB |
- |