제출 #210614

| # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
| 210614 | | tri | Cake 3 (JOI19_cake3) | C++14 | 24 / 100 | 4038 ms | 199012 KiB |
#include <bits/stdc++.h>

using namespace std;

// Short aliases used throughout this submission.
using ll = long long;
using ld = long double;
using pi = pair<int, int>;
using pl = pair<ll, ll>;

using vi = vector<int>;
using vd = vector<ld>;
using vl = vector<ll>;

// Shorthand macros. NOTE: `f` and `s` rewrite every standalone token `f` / `s` that follows
// (into `first` / `second`), so neither name is usable as an identifier below this point.
#define pb push_back
#define f first
#define s second

// Small stdout pretty-printer for local debugging. Every ps(...) call in this submission is
// commented out; a live call would write to cout and mix with the judged answer.
namespace debug {
    // Gate checked by the single-value pr() overload, which every other pr() eventually reaches;
    // ps() still emits its `cout << flush` regardless of this flag.
    const int DEBUG = true;

    // Forward declarations: templates defined earlier (e.g. prIn printing a pair element, or the
    // pair printer printing a vector) can only find overloads that are declared before them.
    template<class T1, class T2>
    void pr(const pair<T1, T2> &x);

    template<class T, size_t SZ>
    void pr(const array<T, SZ> &x);

    template<class T>
    void pr(const vector<T> &x);

    template<class T>
    void pr(const set<T> &x);

    template<class T1, class T2>
    void pr(const map<T1, T2> &x);

    // Leaf case: stream one value with operator<< when DEBUG is on.
    template<class T>
    void pr(const T &x) { if (DEBUG) cout << x; }

    // Prints several values back to back with no separator.
    template<class T, class... Ts>
    void pr(const T &first, const Ts &... rest) { pr(first), pr(rest...); }

    // Prints a pair as {a, b}.
    template<class T1, class T2>
    void pr(const pair<T1, T2> &x) { pr("{", x.f, ", ", x.s, "}"); }

    // Prints any range-for iterable as {e1, e2, ...}.
    template<class T>
    void prIn(const T &x) {
        pr("{");
        bool fst = 1;
        for (auto &a : x) {
            pr(fst ? "" : ", ", a), fst = 0;
        }
        pr("}");
    }

    template<class T, size_t SZ>
    void pr(const array<T, SZ> &x) { prIn(x); }

    template<class T>
    void pr(const vector<T> &x) { prIn(x); }

    template<class T>
    void pr(const set<T> &x) { prIn(x); }

    template<class T1, class T2>
    void pr(const map<T1, T2> &x) { prIn(x); }

    // Terminates a ps() line: newline, then an explicit flush.
    void ps() { pr("\n"), cout << flush; }

    // Prints each argument followed by a space, then a newline
    // (so the line ends with a trailing space before the newline).
    template<class Arg, class... Args>
    void ps(const Arg &first, const Args &... rest) {
        pr(first, " ");
        ps(rest...);
    }
}
using namespace debug;


const int MAXN = 2e5 + 100;
const int LOGN = 20;
const long long INF = 1e15;
int N, K;

// Items after main() sorts them by their second input number: v[i] is the first input number
// and c[i] is twice the sort key, so a window [l, r] is charged c[r] - c[l].
long long v[MAXN], c[MAXN];

// Best value of topK(l, r) + c[l] - c[r] found so far; compute() only ever raises it.
long long bTotal = -INF;

// Each recursion level lv owns one sliding window of item indices
// [sRange[lv].first, sRange[lv].second] (empty while second < first), and
// sSum[lv] caches the sum of the largest min(K, window size) values inside it.
pair<int, int> sRange[LOGN];
long long sSum[LOGN];

// valueRank[i]: 1-based position of item i when items are ordered by v descending
// (ties by index), so every item owns a distinct Fenwick slot and rank 1 is the largest value.
int valueRank[MAXN];

// Per-level Fenwick trees indexed by valueRank: number of window items and sum of their values.
// Flat arrays instead of two std::set per level: same O(log N) bounds per operation, but no
// per-element node allocation and far less memory, which matters at the largest N.
int fenCnt[LOGN][MAXN + 1];
long long fenSum[LOGN][MAXN + 1];

// Largest power of two <= N: the first stride of the Fenwick descent in topKSum().
int fenTopStep = 1;

// Rebuilds valueRank from v[0..N) and empties every level's window and Fenwick trees.
void prepareRun() {
    vector<int> order(N);
    iota(order.begin(), order.end(), 0);
    sort(order.begin(), order.end(), [](int a, int b) {
        return v[a] != v[b] ? v[a] > v[b] : a < b;
    });
    for (int pos = 0; pos < N; pos++) {
        valueRank[order[pos]] = pos + 1;
    }

    fenTopStep = 1;
    while (fenTopStep * 2 <= N) {
        fenTopStep *= 2;
    }

    for (int lv = 0; lv < LOGN; lv++) {
        fill(fenCnt[lv], fenCnt[lv] + N + 1, 0);
        fill(fenSum[lv], fenSum[lv] + N + 1, 0LL);
        sRange[lv] = make_pair(0, -1);
        sSum[lv] = 0;
    }
}

// Adds (delta = +1) or removes (delta = -1) item idx from level lv's Fenwick trees.
void fenUpdate(int lv, int idx, int delta) {
    const long long val = delta * v[idx];
    for (int p = valueRank[idx]; p <= N; p += p & -p) {
        fenCnt[lv][p] += delta;
        fenSum[lv][p] += val;
    }
}

// Sum of the largest min(k, window size) values in level lv's window: descends the Fenwick tree
// to the longest rank prefix holding at most k items. Ranks are unique, so that prefix holds
// exactly min(k, window size) items.
long long topKSum(int lv, int k) {
    int pos = 0;
    int taken = 0;
    long long acc = 0;
    for (int step = fenTopStep; step > 0; step >>= 1) {
        const int nxt = pos + step;
        if (nxt <= N && taken + fenCnt[lv][nxt] <= k) {
            pos = nxt;
            taken += fenCnt[lv][nxt];
            acc += fenSum[lv][nxt];
        }
    }
    return acc;
}

// Moves level setI's window to nRange and refreshes sSum[setI].
// Both ends may only move right; compute() issues each level's windows in increasing order,
// so each level's total movement over a run is O(N).
void reRange(int setI, pair<int, int> nRange) {
    pair<int, int> &cRange = sRange[setI];
    assert(cRange.first <= nRange.first && cRange.second <= nRange.second);

    // Extend on the right before trimming on the left, so an item is never removed before it
    // was added, even when the new window does not overlap the old one.
    while (cRange.second < nRange.second) {
        fenUpdate(setI, ++cRange.second, +1);
    }
    while (cRange.first < nRange.first) {
        fenUpdate(setI, cRange.first++, -1);
    }

    sSum[setI] = topKSum(setI, K);
}

// Sum of the K largest values in level setI's window, or -INF when it holds fewer than K items.
long long getTotal(int setI) {
    const pair<int, int> &win = sRange[setI];
    if (win.second - win.first + 1 >= K) {
        return sSum[setI];
    }
    return -INF;
}

// Divide and conquer over left ends l in [l1, l2] whose best right end lies in [r1, r2].
// For the middle left end cL it scans r, keeps the first (leftmost) r maximising
// getTotal + c[cL] - c[r], folds that value into bTotal, and recurses on both halves with the
// r-range split at that r; this relies on the best r being non-decreasing in l.
// `level` selects the sliding window used at this recursion depth.
void compute(int level, int l1, int l2, int r1, int r2) {
    if (level == 0) {
        // Root call: derive value ranks from v[0..N) and start every level from an empty window.
        prepareRun();
    }
    if (l1 > l2) {
        return;  // no left end to try (only possible at the root, when N < K)
    }
    assert(level < LOGN);
    assert(0 <= l1 && 0 <= r1);

    const int cL = (l1 + l2) / 2;
    long long maxTotal = -INF;
    int maxR = -1;

    for (int cR = max(cL, r1); cR <= r2; cR++) {
        reRange(level, make_pair(cL, cR));
        const long long cTotal = getTotal(level) + c[cL] - c[cR];
        if (cTotal > maxTotal) {
            maxTotal = cTotal;
            maxR = cR;
        }
    }
    assert(maxR != -1);
    bTotal = max(bTotal, maxTotal);

    if (l1 < cL) {
        compute(level + 1, l1, cL - 1, r1, maxR);
    }
    if (cL < l2) {
        compute(level + 1, cL + 1, l2, maxR, r2);
    }
}


// Reads N, K and N pairs (x, y); sorts the pairs by y; stores v[i] = x and c[i] = 2 * y;
// runs the divide-and-conquer search and prints the best total found.
int main() {
    // Up to 2 * N + 2 integers are read: detach cin from C stdio and from cout flushing.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> N >> K;
    vector<pi> items;
    for (int i = 0; i < N; i++) {
        int x, y;
        cin >> x >> y;
        items.pb({y, x});  // sort key first
    }
    sort(items.begin(), items.end());

    for (int i = 0; i < N; i++) {
        v[i] = items[i].s;
        c[i] = 2LL * items[i].f;  // doubled: compute() charges c[r] - c[l] per window
    }

    fill(sSum, sSum + LOGN, 0);
    fill(sRange, sRange + LOGN, make_pair(0, -1));

    compute(0, 0, N - K, 0, N - 1);
    cout << bTotal << '\n';
}

컴파일 시 표준 에러 (stderr) 메시지

cake3.cpp: In function 'void reRange(int, pi)':
cake3.cpp:118:24: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
     while (cSet.size() < K && cExt.size()) {
            ~~~~~~~~~~~~^~~
In file included from /usr/include/c++/7/cassert:44:0,
                 from /usr/include/x86_64-linux-gnu/c++/7/bits/stdc++.h:33,
                 from cake3.cpp:1:
cake3.cpp:137:24: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
     assert(cSet.size() == min(K, cRange.s - cRange.f + 1));
            ~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
cake3.cpp: In function 'll getTotal(int)':
cake3.cpp:142:24: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
     assert(cSet.size() <= K);
            ~~~~~~~~~~~~^~~~
cake3.cpp:143:21: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
     if (cSet.size() == K) {
         ~~~~~~~~~~~~^~~~
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...