제출 #330966

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
33096612tqianCake 3 (JOI19_cake3)C++17
100 / 100
1149 ms150636 KiB
#include<bits/stdc++.h>
using namespace std;

using ll = long long;

// Capacity of the node pool t[]. Every modify() call takes one fresh node per
// level on its root-to-leaf path, so the pool must exceed n * (tree depth).
constexpr int M = 5'000'000;
// Bound for the per-item arrays below (items are indexed 1..n, plus slack).
constexpr int N = 200'000 + 5;

// One node of the persistent segment tree keyed by compressed value rank.
struct Node {
    int lc{0};   // index in t[] of the left child (0 = empty subtree)
    int rc{0};   // index in t[] of the right child (0 = empty subtree)
    ll sum{0};   // sum of the original values stored below this node
    int num{0};  // how many values are stored below this node
};
Node t[M];       // node pool; t[0] is never written and acts as the empty node

int cnt = 0;     // index of the last node handed out from t[]
int ti[N];       // ti[i] = root of the version holding items 1..i
ll v[N];         // item values, in the order of ascending c
ll c[N];         // item c-keys, sorted ascending
ll ans[N];       // ans[l] = best f(l, r) found for left end l
ll ord[N];       // ord[k] = original value whose compressed rank is k

int n, m;        // number of items; how many largest values f() sums

// Persistent segment-tree insertion over the compressed value range [l, r].
// Copies every node on the root-to-leaf path of position x into fresh slots
// of t[] (version p itself is left untouched) and records one more
// occurrence of the value `val` at leaf x.
//   p   : root of the previous version (0 refers to the never-written t[0])
//   x   : compressed rank of the value, expected to lie in [l, r]
//   val : the original value, accumulated into `sum`. It is taken as
//         long long so 64-bit item values are not narrowed to int.
// Returns the root index of the new version.
int modify(int p, int l, int r, int x, long long val) {
    const int u = ++cnt;
    if (l == r) {
        t[u].sum = t[p].sum + val;
        t[u].num = t[p].num + 1;
        return u;
    }
    const int mid = (l + r) / 2;
    if (x <= mid) {
        t[u].lc = modify(t[p].lc, l, mid, x, val);
        t[u].rc = t[p].rc;  // untouched half is shared with version p
    } else {
        t[u].lc = t[p].lc;  // untouched half is shared with version p
        t[u].rc = modify(t[p].rc, mid + 1, r, x, val);
    }
    t[u].sum = t[t[u].lc].sum + t[t[u].rc].sum;
    t[u].num = t[t[u].lc].num + t[t[u].rc].num;
    return u;
}
// make sure it's stuff for l - 1, r
ll query(int num, int pl, int pr, int l, int r) {
    // need the highest num in the range [l, r]
    if (num == 0) 
        return 0;
    if (l == r) {
        return 1LL * ord[l] * num;
    }
    int mid = (l + r) / 2;
    ll rsum = t[t[pr].rc].sum - t[t[pl].rc].sum;
    int rnum = t[t[pr].rc].num - t[t[pl].rc].num;
    if (rnum >= num) 
        return query(num, t[pl].rc, t[pr].rc, mid + 1, r);
    else 
        return rsum + query(num - rnum, t[pl].lc, t[pr].lc, l, mid);
}
// Sum of the `num` largest values among items l..r (1-based, inclusive).
// It is answered as the difference between tree versions ti[r] and ti[l - 1],
// searched over the full compressed range [1, n].
ll evaluate(int num, int l, int r) {
    const int before = ti[l - 1];
    const int through = ti[r];
    return query(num, before, through, 1, n);
}
// Score of the item index range [l, r]: the sum of the m largest v[] inside
// it, minus twice the spread c[r] - c[l].
// Items are sorted by c ascending, so c[l] and c[r] are the range's extreme
// c values. (Presumably this is the cyclic-arrangement cost 2*(max c - min c);
// confirm against the problem statement.)
ll f(int l, int r) {
    const ll top_values = evaluate(m, l, r);
    const ll spread_penalty = 2 * (c[r] - c[l]);
    return top_values - spread_penalty;
}
void dnc(int l, int r, int gl, int gr) {
    if (l > r) return;
    int mid = (l + r) / 2;
    int best = gl;
    for (int i = max(mid + m - 1, gl); i <= gr; i++) {
        ll cur = f(mid, i);
        assert(i - mid + 1 >= m);
        if (cur > ans[mid])
            ans[mid] = cur, best = i;
    }
    dnc(l, mid - 1, gl, best);
    dnc(mid + 1, r, best, gr);
}
int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cin >> n >> m;
    for (int i = 1; i <= n; i++) 
        ans[i] = -2e18;
    vector<array<ll, 2>> parts(n);
    for (int i = 0; i < n; i++)
        cin >> parts[i][0] >> parts[i][1];
    sort(parts.begin(), parts.end(), [](array<ll, 2> a, array<ll, 2> b) {
        return a[1] < b[1];
    });
    for (int i = 1; i <= n; i++) {
        v[i] = parts[i - 1][0];
        c[i] = parts[i - 1][1];
    }  
    set<int> vals;
    for (int i = 1; i <= n; i++) 
        vals.insert(v[i]);
    map<int, int> conv;
    int num = 1;
    for (int x : vals) 
        conv[x] = num, ord[num] = x, num++;
    for (int i = 1; i <= n; i++) {
        ti[i] = modify(ti[i - 1], 1, n, conv[v[i]], v[i]);
    }
    dnc(1, n - m + 1, 1, n);
    ll res = -2e18  ;
    for (int i = 1; i <= n; i++)
        res = max(res, ans[i]);
    cout << res << '\n';
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...