Submission #841628

| # | Submission time | Handle | Problem | Language | Result | Execution time | Memory |
|---|---|---|---|---|---|---|---|
| 841628 | 2023-09-01T18:50:25Z | Ormlis | Closing Time (IOI23_closing) | C++17 | 8 / 100 | 1000 ms | 35664 KB |
#include "bits/stdc++.h"
#include "closing.h"

// Short-hand loop and range macros used throughout the solution.
#define rep(idx, cnt) for (int idx = 0; idx < (cnt); ++idx)
#define all(c) (c).begin(), (c).end()
#define rall(c) (c).rbegin(), (c).rend()
#define ar array

using namespace std;

// 64-bit integer and common container aliases.
using ll = long long;
using pi = pair<int, int>;
using vi = vector<int>;
using vl = vector<ll>;
using vpi = vector<pi>;

// "Infinity" sentinel for distances: exactly 4 * 10^18, which fits in a signed 64-bit ll.
constexpr ll INF = 4'000'000'000'000'000'000LL;

// my_assert: debugging check that never aborts. If `c` is false it spins
// forever, so on a judge a violated invariant shows up as a time-limit verdict
// instead of a crash or a wrong answer.
//
// The spin counter is `volatile`. C++ lets the compiler assume that a loop with
// no observable side effects eventually terminates ([intro.progress]), so a
// side-effect-free infinite loop is undefined behaviour and may be optimised
// away. Each read/write of a volatile object is an observable side effect,
// which keeps the hang well-defined. The counter cycles within [-50, 10], so it
// never overflows.
void my_assert(bool c) {
    if (c) return;
    volatile int x = 0;
    while (true) {
        if (x == 10) x = -5 * x;
        x = x + 1;
    }
}

// Iterative bottom-up segment tree over long long values, supporting point-add
// and half-open range-sum queries.
// Leaves are stored at t[n .. 2n); node v holds t[2v] + t[2v + 1]; t[1] is the root.
struct SegTree {
    std::vector<long long> t;
    int n = 0;  // number of leaves (a power of two); 0 until build() is called

    // (Re)initialises the tree with all-zero values.
    // n becomes the smallest power of two >= 2 that is strictly greater than
    // `size`, so indices [0, size] are all addressable.
    void build(int size) {
        n = 2;
        while (n <= size) n *= 2;
        // assign (rather than resize) so that a repeated build() also clears old sums.
        t.assign(2 * n, 0);
    }

    // Adds x to position i (requires 0 <= i < n) and refreshes its ancestors.
    void upd(int i, long long x) {
        for (t[i += n] += x; i > 1; i >>= 1) t[i >> 1] = t[i] + t[i ^ 1];
    }

    // Returns the sum over positions [l, r); 0 when l >= r.
    long long get(int l, int r) const {
        long long s = 0;
        for (l += n, r += n; l < r; l >>= 1, r >>= 1) {
            if (l & 1) s += t[l++];
            if (r & 1) s += t[--r];
        }
        return s;
    }
};

// BuildC: breadth-first traversal of `g` from `x`, relaxing each edge once.
//
// n: number of vertices in g.
// x: source vertex.
// g: adjacency lists of (neighbour, edge weight) pairs. Taken by const
//    reference so the whole graph is not copied on every call.
// Returns {c, p}:
//   c[v] is the accumulated weight from x to v (INF if v is never reached);
//   p[v] is the vertex v was reached from (-1 for x and for unreached v).
//
// NOTE(review): relaxing each vertex only once gives correct distances only
// when g is a tree. The my_assert below checks that no vertex is improved twice.
pair<vl, vi> BuildC(int n, int x, const vector<vpi> &g) {
    vector<ll> c(n, INF);
    vector<int> p(n, -1);
    c[x] = 0;
    queue<int> q;
    q.push(x);
    while (!q.empty()) {
        int v = q.front();
        q.pop();
        for (auto [u, w]: g[v]) {
            if (c[u] > w + c[v]) {
                my_assert(c[u] == INF);
                c[u] = w + c[v];
                p[u] = v;
                q.push(u);
            }
        }
    }
    return {std::move(c), std::move(p)};
}

// max_score: entry point for the "Closing Time" task (IOI23_closing).
//
// As used below:
//   n      vertex count;
//   x, y   the two start vertices;
//   edge i joins U[i] and V[i] with weight W[i];
//   k      the bound that accumulated costs (`sum`, `use`) are compared against.
// Returns `ans`, which the final my_assert expects to be at most 2 * n.
//
// k and every weight are doubled up front so that max(cx[i], cy[i]) / 2 is
// always an integer (see my_assert(mx[i] % 2 == 0) below).
//
// NOTE(review): this listing appears to have lost lines.
//   - The braces do not balance.
//   - Two comparator lambdas further down have no enclosing call.
//   - Several containers are read without a visible fill/sort/build.
//   - The compiler diagnostics appended to the listing cite line numbers that
//     do not match this text.
// Recover the original submission before changing logic here.
//
// NOTE(review): as far as its visible body shows, my_assert spins forever on
// failure. Any violated invariant here would therefore show up as a time-limit
// verdict rather than a crash.
int max_score(int n, int x, int y, ll k, vi U, vi V, vi W) {
    k *= 2;
    int ans = 2;
    // Undirected adjacency lists of (neighbour, doubled weight).
    vector<vpi> g(n);
    rep(i, U.size()) {
        int u = U[i];
        int v = V[i];
        int w = W[i];
        w *= 2;
        my_assert(w > 0);
        g[u].emplace_back(v, w);
        g[v].emplace_back(u, w);
    // cx/cy: distances from x/y.
    // px/py: the vertex each vertex was reached from when starting at x/y.
    auto [cx, px] = BuildC(n, x, g);
    auto [cy, py] = BuildC(n, y, g);
        // Candidate 1: ans tracks the best i + j such that the first i entries of
        // sx (= cx) plus the first j entries of sy (= cy) sum to <= k.
        // NOTE(review): as shown, sx/sy are never sorted, j is never changed, and
        // sy[j] is read with j == n (one past the end). A sort and a `--j` were
        // presumably lost from this listing.
        auto sx = cx, sy = cy;
        ll sum = accumulate(all(sy), 0ll);
        int j = n;
        for (int i = 0; i <= n; ++i) {
            if (i) sum += sx[i - 1];
            while (j > 0 && sum > k) {
                sum -= sy[j];
            if (sum > k) break;
            ans = max(ans, i + j);
    // mids: vertices on the path from x to y (walked via py) kept as candidates.
    // A path vertex is skipped when the next path vertex is still strictly closer
    // to x, or the previous one is already strictly closer to y.
    // NOTE(review): no push into `mids` is visible in this listing.
    vector<int> mids;
        vector<int> path;
        for(int v = x; v != -1; v = py[v]) path.push_back(v);
        rep(i, path.size()) {
            int v = path[i];
            if (i + 1 < path.size() && cy[path[i + 1]] > cx[path[i + 1]]) continue;
            if (i > 0 && cx[path[i - 1]] > cy[path[i - 1]]) continue;
//    my_assert(mids.size() <= 4);
    // For each candidate `mid`:
    //   - mid gets state 2;
    //   - the vertices from mid toward x (via px, x included) and toward y
    //     (via py, y included) get state 1;
    //   - `use` is their total cost.
    // Further vertices are then added greedily below.
    for (auto mid: mids) {
        // Disabled alternative evaluation (commented out in the submission).
//        if (true) {
//            my_assert(mid >= x && mid <= y);
//            int l = x;
//            int r = y;
//            while (true) {
//                ll s = 0;
//                for (int i = l; i <= r; ++i) s += min(cx[i], cy[i]);
//                s += abs(cx[mid] - cy[mid]);
//                if (s > k) break;
//                int score = r - l + 1 + 1;
//                vector<ll> can;
//                for (int i = l; i <= r; ++i) {
//                    if (i == mid) continue;
//                    can.push_back(abs(cx[i] - cy[i]));
//                }
//                ans = max(ans, score);
//                sort(all(can));
//                for (auto &v: can) {
//                    s += v;
//                    if (s > k) break;
//                    score++;
//                    ans = max(ans, score);
//                }
//                if (l == 0 && r == n - 1) break;
//                if (l == 0) {
//                    r++;
//                } else if (r == n - 1) {
//                    l--;
//                } else if (min(cx[l - 1], cy[l - 1]) > min(cx[r + 1], cy[r + 1])) {
//                    r++;
//                } else {
//                    l--;
//                }
//            }
//            continue;
//        }
        // state[v] is the number of points v contributes (0, 1 or 2);
        // score_default below is their sum.
        vector<int> state(n);
        // mid is charged the larger of its two distances.
        ll use = max(cx[mid], cy[mid]);
        state[mid] = 2;
        for (int vx = px[mid]; vx != -1; vx = px[vx]) {
            my_assert(state[vx] == 0);
            my_assert(cx[vx] < cy[vx]);
            state[vx] = 1;
            use += cx[vx];
        for (int vy = py[mid]; vy != -1; vy = py[vy]) {
            my_assert(state[vy] == 0);
            my_assert(cx[vy] > cy[vy]);
            state[vy] = 1;
            use += cy[vy];
        // The baseline for this mid already exceeds k: try the next candidate.
        if (use > k) continue;
        my_assert(state[x] && state[y]);
        int score_default = accumulate(all(state), 0);
        ans = max(ans, score_default);
        // Per vertex: diff = |cx - cy|, mn/mx = min/max of the two distances.
        // xx: cost levels, one mx / 2 value per vertex, deduplicated.
        // NOTE(review): unique() only removes adjacent duplicates, and GetInd's
        // lower_bound needs sorted input, but no sort of xx is visible. GetInd is
        // also called with diff[i] values that are not pushed into xx here.
        vector<ll> xx;
        vector<ll> diff(n), mx(n), mn(n);
        rep(i, n) {
            diff[i] = abs(cx[i] - cy[i]);
            mn[i] = min(cx[i], cy[i]);
            mx[i] = max(cx[i], cy[i]);
            my_assert(mx[i] % 2 == 0);
            xx.push_back(mx[i] / 2);
        xx.resize(unique(all(xx)) - xx.begin());
        // GetInd: index of value x in xx; asserts that x is present.
        // (This lambda's parameter x shadows max_score's x.)
        auto GetInd = [&](ll x) {
            int i = lower_bound(all(xx), x) - xx.begin();
            my_assert(i < xx.size() && xx[i] == x);
            return i;
        int sz = xx.size();
        // Trees indexed by level in xx:
        //   st1/ct1  sum/count of diff[i] at level diff[i] (one extra point each);
        //   st2/ct2  sum/count of mx[i] at level mx[i] / 2 for vertices in the
        //            two-point pool (see Check2).
        // NOTE(review): no build() call is visible before the first upd().
        SegTree st1, st2, ct1, ct2;;;;;
        // Vertex orders consumed by Check2 (ord2) and by the final loop (ord1).
        // NOTE(review): no code filling ord1/ord2 is visible.
        vi ord2;
        vi ord1;
        // Register the upgrade cost diff[i] of every vertex already in state 1.
        rep(i, n) {
            if (state[i] == 2) continue;
            if (state[i] == 1) {
                st1.upd(GetInd(diff[i]), diff[i]);
                ct1.upd(GetInd(diff[i]), 1);
            // Disabled code (commented out in the submission).
//            used[i] = false;
//            if (mn[i] > diff[i]) {
//                st2.upd(GetInd(mx[i] / 2), mx[i]);
//                ct2.upd(GetInd(mx[i] / 2), 1);
//            }
//        rep(v, n) if (state[v]) {
//                for (auto &[u, w]: g[v]) {
//                    if (state[u] == 0 && !used[u]) {
//                        dfs(u);
//                    }
//                }
//            }
        // NOTE(review): the two lines below are comparator arguments whose calls
        // are missing from this listing. Judging by how ord1 and ord2 are used,
        // they presumably order ord1 by (mn, diff) and ord2 by diff.
             [&](const int &i, const int &j) { return make_pair(mn[i], diff[i]) < make_pair(mn[j], diff[j]); });
             [&](const int &i, const int &j) { return diff[i] < diff[j]; });
        // FindNext(l): smallest level index >= l with a nonzero ct1 count, or sz
        // if there is none. Binary search on prefix counts of ct1.
        auto FindNext = [&](int l) {
//            for(int i = l; i < sz; ++i) {
//                if (ct1.get(i, i + 1)) return i;
//            }
//            return sz;
            int L = l - 1;
            int R = sz;
            while (R - L > 1) {
                int mid = (L + R) / 2;
                if (ct1.get(l, mid + 1)) {
                    R = mid;
                } else {
                    L = mid;
            my_assert(ct1.get(l, R) == 0);
            if (R != sz) my_assert(ct1.get(R, R + 1) > 0);
            return R;
        // Check: binary-search the longest prefix of levels whose st1 + st2 total
        // still fits in k - use; r is the first level that does not fit.
        // Everything below r is counted. Level r is then filled partially:
        //   - two-point items first, each costing 2 * xx[r];
        //   - then one-point items, each costing xx[r].
        // The t == 1 pass first spends xx[nxt] on the next nonempty ct1 level
        // after r (no matching score increment is visible there).
        // NOTE(review): no call to Check() is visible in this listing.
        auto Check = [&]() {
            if (use > k) return;
            // Disabled earlier sorted-list variant (commented out in the submission).
//            vector<ll> val2, val1;
//            rep(i, n) {
//                if (state[i] == 1) {
//                    val1.push_back(diff[i]);
//                }
//                if (state[i] == 0) {
//                    val2.push_back(mx[i]);
//                }
//            }
//            int score = score_default;
//            sort(all(val1));
//            sort(all(val2));
//            int i = 0;
//            int j = 0;
//            ll have = k - use;
//            while (true) {
//                ans = max(ans, score);
//                ll to = INF;
//                if (i < val1.size() && val1[i] <= have) to = min(to, val1[i]);
//                if (j < val2.size() && val2[j] <= have) to = min(to, val2[j] / 2);
//                if (j < val2.size() && val2[j] <= have && to == val2[j] / 2) {
//                    have -= val2[j];
//                    score += 2;
//                    j++;
//                } else if (i < val1.size() && val1[i] <= have && to == val1[i]) {
//                    have -= val1[i];
//                    score++;
//                    i++;
//                } else {
//                    my_assert(to == INF);
//                    break;
//                }
//            }
//            return;
//            int l = -1;
//            int r = 0;
//            while (r < sz && use + st1.get(0, r + 1) + st2.get(0, r + 1) <= k) {
//                r++;
//                l++;
//            }
            int l = -1;
            int r = sz;
            while (r - l > 1) {
                int mid = (l + r) / 2;
                ll s = st1.get(0, mid + 1) + st2.get(0, mid + 1);
                if (s + use <= k) {
                    l = mid;
                } else {
                    r = mid;
            my_assert(l != -1);
            int score = score_default + ct1.get(0, r) + ct2.get(0, r) * 2;
            // Every level fits within the bound.
            if (r == sz) {
//                my_assert(score == 2 * n);
                ans = max(ans, score);
            int cnt2 = ct2.get(r, r + 1);
            int cnt1 = ct1.get(r, r + 1);
            int nxt = FindNext(r + 1);
            int score2 = score;
            for(int t = 0; t <= 1; ++t) {
                score = score2;
                if (t && nxt == sz) continue;
                ll have = k - use - st1.get(0, r) - st2.get(0, r);
                if (t && xx[nxt] <= have) {
                    have -= xx[nxt];
                my_assert(have >= 0);
                if (cnt2) {
                    ll t2 = min(1ll * cnt2, have / (2 * xx[r]));
                    cnt2 -= t2;
                    have -= t2 * xx[r] * 2;
                    score += t2 * 2;

                if (cnt1) {
                    ll t1 = min(1ll * cnt1, have / xx[r]);
                    cnt1 -= t1;
                    have -= t1 * xx[r];
                    score += t1;
                my_assert(have >= 0);
                ans = max(ans, score);
        // ok[i]: vertex i currently sits in the two-point pool (st2/ct2).
        vector<bool> ok(n);
        // uk2: next unprocessed position in ord2.
        int uk2 = 0;
        // Check2(value): moves every not-yet-processed ord2 vertex with
        // diff <= value that is still in state 0 into the two-point pool,
        // and marks it in ok.
        auto Check2 = [&] (ll value) {
            while (uk2 < ord2.size() && diff[ord2[uk2]] <= value) {
                int i = ord2[uk2++];
                if (state[i] == 0) {
                    ok[i] = true;
                    st2.upd(GetInd(mx[i] / 2), mx[i]);
                    ct2.upd(GetInd(mx[i] / 2), 1);
        // Take the ord1 vertices one at a time, each at cost mn[i]:
        //   1. admit into the two-point pool every pending vertex with diff < mn[i];
        //   2. move i to state 1;
        //   3. drop i from the two-point pool if it was there;
        //   4. register its upgrade cost diff[i] in st1/ct1.
        // Stops once use exceeds k.
        for (auto &i: ord1) {
            Check2(mn[i] - 1);
            my_assert(state[i] == 0);
            state[i] = 1;
            use += mn[i];
            if (use > k) break;
            if (ok[i]) {
                st2.upd(GetInd(mx[i] / 2), -mx[i]);
                ct2.upd(GetInd(mx[i] / 2), -1);
            st1.upd(GetInd(diff[i]), diff[i]);
            ct1.upd(GetInd(diff[i]), 1);
    // Sanity check: the score never exceeds 2 points per vertex.
    my_assert(ans <= 2 * n);
    return ans;

Compilation message

closing.cpp: In function 'int max_score(int, int, int, ll, vi, vi, vi)':
closing.cpp:4:36: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
    4 | #define rep(i, n) for(int i = 0; i < (n); ++i)
      |                                    ^
closing.cpp:78:5: note: in expansion of macro 'rep'
   78 |     rep(i, U.size()) {
      |     ^~~
closing.cpp:4:36: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
    4 | #define rep(i, n) for(int i = 0; i < (n); ++i)
      |                                    ^
closing.cpp:109:9: note: in expansion of macro 'rep'
  109 |         rep(i, path.size()) {
      |         ^~~
closing.cpp:111:23: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  111 |             if (i + 1 < path.size() && cy[path[i + 1]] > cx[path[i + 1]]) continue;
      |                 ~~~~~~^~~~~~~~~~~~~
closing.cpp: In lambda function:
closing.cpp:189:25: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  189 |             my_assert(i < xx.size() && xx[i] == x);
      |                       ~~^~~~~~~~~~~
closing.cpp: In lambda function:
closing.cpp:338:24: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  338 |             while (uk2 < ord2.size() && diff[ord2[uk2]] <= value) {
      |                    ~~~~^~~~~~~~~~~~~
# Verdict Execution time Memory Grader output
1 Execution timed out 1068 ms 348 KB Time limit exceeded
2 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Correct 164 ms 35664 KB Output is correct
2 Correct 203 ms 34892 KB Output is correct
3 Correct 96 ms 3096 KB Output is correct
# Verdict Execution time Memory Grader output
1 Correct 0 ms 344 KB Output is correct
2 Correct 0 ms 348 KB Output is correct
3 Correct 0 ms 600 KB Output is correct
4 Correct 1 ms 348 KB Output is correct
5 Execution timed out 1076 ms 348 KB Time limit exceeded
6 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Correct 0 ms 344 KB Output is correct
2 Correct 0 ms 348 KB Output is correct
3 Correct 0 ms 600 KB Output is correct
4 Correct 1 ms 348 KB Output is correct
5 Execution timed out 1076 ms 348 KB Time limit exceeded
6 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Correct 0 ms 344 KB Output is correct
2 Correct 0 ms 348 KB Output is correct
3 Correct 0 ms 600 KB Output is correct
4 Correct 1 ms 348 KB Output is correct
5 Execution timed out 1076 ms 348 KB Time limit exceeded
6 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Execution timed out 1068 ms 348 KB Time limit exceeded
2 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Execution timed out 1068 ms 348 KB Time limit exceeded
2 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Execution timed out 1068 ms 348 KB Time limit exceeded
2 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Execution timed out 1068 ms 348 KB Time limit exceeded
2 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Execution timed out 1068 ms 348 KB Time limit exceeded
2 Halted 0 ms 0 KB -