#include <bits/stdc++.h>
using namespace std;
typedef long long ll;              // 64-bit accumulator type used throughout
typedef std::pair<ll, int> pli;    // (sum, count) and (score, index) pairs
#define pb push_back
#define ff first
#define ss second
constexpr ll infm = -1000000000000000000LL;  // "minus infinity" for maximisation
/**
 * Persistent segment tree indexed by (compressed) value position 1..n.
 *
 * Every call to upd(p, x) creates a new version: root[i] is the tree after the
 * first i updates and root[0] is the empty tree. Each node stores the sum (s)
 * and count (c) of the values inserted at the positions it covers, so
 * subtracting version l-1 from version r describes exactly the values inserted
 * by updates l..r.
 *
 * Requirements on the caller:
 *  - n >= 1, and at most n updates (root has n + 1 slots);
 *  - positions must be assigned monotonically by value (smaller value ->
 *    smaller position) and equal values must share a position; get() relies
 *    on every value at one position being identical.
 *
 * Nodes are allocated with `new` and never freed.
 */
struct PST{
    // Same types as the file-level aliases; declared here so the struct is self-contained.
    using ll = long long;
    using pli = std::pair<ll, int>;

    struct node{
        node *l, *r;
        ll s;   // sum of the values stored in this subtree
        int c;  // number of values stored in this subtree
        // Internal node: aggregates whichever children are present.
        node(node *ls, node *rs){
            l = ls; r = rs;
            s = c = 0;
            if (l){
                s += l -> s;
                c += l -> c;
            }
            if (r){
                s += r -> s;
                c += r -> c;
            }
        }
        // Leaf with an explicit sum and count.
        node(ll s1, int c1){
            l = r = nullptr;
            s = s1; c = c1;
        }
    };
    std::vector<node*> root;  // root[i] = version after i updates
    std::vector<int> a;       // a[p] = value stored at position p
    int n, cc;                // positions are [1, n]; cc = number of updates so far

    // Creates the empty version root[0] over positions 1..ns (requires ns >= 1).
    PST(int ns){
        n = ns;
        root.resize(n + 1);
        a.resize(n + 1);
        root[0] = build(1, n);
        cc = 0;
    }

    // Builds a complete empty tree over [tl, tr]; every internal node gets both children.
    node* build(int tl, int tr){
        if (tl == tr) return new node(0, 0);
        int tm = (tl + tr) / 2;
        return new node(build(tl, tm), build(tm + 1, tr));
    }

    // Returns a new version of subtree v (covering [tl, tr]) with value x added
    // at position p. Only the root-to-leaf path is copied; other nodes are shared.
    node* upd(node *v, int tl, int tr, int& p, int& x){
        if (tl == tr) return new node(v -> s + x, v -> c + 1);
        int tm = (tl + tr) / 2;
        // build() creates every node, so both children exist. Children are never
        // created lazily here: doing so would mutate nodes shared with older versions.
        if (p <= tm) return new node(upd(v -> l, tl, tm, p, x), v -> r);
        return new node(v -> l, upd(v -> r, tm + 1, tr, p, x));
    }

    // Records value x at position p as the new version root[cc + 1].
    void upd(int p, int x){
        a[p] = x;
        cc++;
        root[cc] = upd(root[cc - 1], 1, n, p, x);
    }

    // Returns the smallest position t such that the values inserted between
    // versions v1 (exclusive) and v2 (inclusive) include at least k values at
    // positions <= t, i.e. the position of the k-th smallest such value.
    // Requires 1 <= k <= (count in v2) - (count in v1).
    int find(node *v1, node *v2, int tl, int tr, int k){
        if (tl == tr) return tl;
        int tm = (tl + tr) / 2, lf = (v2 -> l -> c) - (v1 -> l -> c);
        if (lf < k){
            return find(v1 -> r, v2 -> r, tm + 1, tr, k - lf);
        }
        return find(v1 -> l, v2 -> l, tl, tm, k);
    }

    // {sum, count} of the values at positions [l, r] within subtree v covering
    // [tl, tr]. An empty range (l > r) yields {0, 0}.
    pli sum(node *v, int tl, int tr, int& l, int& r){
        if (l > tr || r < tl) return {0, 0};
        if (l <= tl && tr <= r) return {v -> s, v -> c};
        int tm = (tl + tr) / 2;
        pli x = sum(v -> l, tl, tm, l, r), y = sum(v -> r, tm + 1, tr, l, r);
        return {x.first + y.first, x.second + y.second};
    }

    // {sum, count} of the values at positions [l, r] in version v.
    pli sum(int v, int l, int r){
        return sum(root[v], 1, n, l, r);
    }

    // Sum of the k smallest values inserted by updates l..r (1-based, inclusive).
    // Requires 1 <= l <= r <= cc and 1 <= k <= r - l + 1.
    ll get(int l, int r, int k){
        // Position holding the k-th smallest value in the range.
        int t = find(root[l - 1], root[r], 1, n, k);
        // Take every value strictly below position t in full; by the choice of t
        // there are fewer than k of them.
        pli lo = sum(l - 1, 1, t - 1);
        pli hi = sum(r, 1, t - 1);
        int taken = hi.second - lo.second;
        ll out = hi.first - lo.first;
        // Fill the remaining k - taken slots with copies of the value at t. Position
        // t may hold more copies than needed (equal values share a position), so
        // counting all of them would over-count.
        out += 1LL * (k - taken) * a[t];
        return out;
    }
};
// Reads n, k, then a[1..n] and b[1..n]. Prints the maximum over segments [l, r]
// with r - l + 1 >= k of
//   T.get(l, r, k) - (a[l] + ... + a[r]),
// where T.get(l, r, k) is the sum of the k smallest values among b[l..r].
// Prints infm if no segment is long enough.
int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    int n, k; cin>>n>>k;
    vector<int> a(n + 1);
    vector<ll> p(n + 1);  // p[i] = a[1] + ... + a[i]
    for (int i = 1; i <= n; i++){
        cin>>a[i];
        p[i] = p[i - 1] + a[i];
    }
    vector<int> b(n + 1);
    for (int i = 1; i <= n; i++) cin>>b[i];

    // Coordinate-compress b into positions 1..n: positions increase with the value
    // and equal values share a position, as the tree's k-smallest query requires.
    // The rank is offset by +1 rather than relying on a 0 sentinel in the sorted
    // array, so zero or negative b values can no longer map to position 0, which
    // lies outside the tree's [1, n] range.
    vector<int> sorted_b(b.begin() + 1, b.end());
    sort(sorted_b.begin(), sorted_b.end());

    PST T(n);  // version i holds b[1..i]
    for (int i = 1; i <= n; i++){
        int j = (int)(lower_bound(sorted_b.begin(), sorted_b.end(), b[i]) - sorted_b.begin()) + 1;
        T.upd(j, b[i]);
    }

    // Score of segment [l, r]; segments shorter than k are infeasible.
    auto f = [&](int l, int r){
        if ((r - l + 1) < k) return infm;
        return T.get(l, r, k) - (p[r] - p[l - 1]);
    };

    ll out = infm;
    // Divide and conquer over left endpoints. For the middle left endpoint m, scan
    // right endpoints in [max(m, l1), r1], then split the search range at the best
    // one. This assumes the best right endpoint is non-decreasing in the left
    // endpoint. NOTE(review): confirm that monotonicity holds for this problem.
    function<void(int, int, int, int)> solve = [&](int l, int r, int l1, int r1){
        if (l > r) return;
        int m = (l + r) / 2;

        pli opt = {infm, 0};  // (best score, its right endpoint)
        for (int i = max(m, l1); i <= r1; i++){
            opt = max(opt, {f(m, i), i});
        }

        out = max(out, opt.ff);
        solve(l, m - 1, l1, opt.ss);
        solve(m + 1, r, opt.ss, r1);
    };
    solve(1, n, 1, n);

    cout<<out<<"\n";
}
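/*
 * NOTE: the tables below appear to be residue pasted from an online judge's
 * submission page (the results had not loaded yet). They are kept inside a
 * comment so that the file still compiles.
 *
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|---|---|---|---|
 * | Fetching results... |
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|---|---|---|---|
 * | Fetching results... |
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|---|---|---|---|
 * | Fetching results... |
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|---|---|---|---|
 * | Fetching results... |
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|---|---|---|---|
 * | Fetching results... |
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|---|---|---|---|
 * | Fetching results... |
 */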