제출 #1233380

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1233380 | | Tenis0206 | Tricks of the Trade (CEOI23_trade) | C++20 | 50 / 100 | 3399 ms | 376160 KiB
#include <bits/stdc++.h>
#define int long long

using namespace std;

// NOTE: `int` is redefined to `long long` by the macro near the top of the file,
// so every `int` below is a 64-bit value.

// Sentinel: `-oo` is the starting value when searching for a maximum.
const int oo = LLONG_MAX;
// Capacity used to size the per-index arrays below.
const int nmax = 3e5;
// Upper end of the value range [1, vmax] covered by the persistent segment tree.
const int vmax = 1e9;

// n and k are the first two numbers read from the input.
int n, k;
// c[i] and s[i] are the two per-index input sequences (1-based).
int c[nmax + 5], s[nmax + 5];

// r[i]   : best cost found for a range whose left endpoint is i.
// poz[i] : right endpoint at which r[i] was attained.
int r[nmax + 5], poz[nmax + 5];

// Prefix sums of c: sum[i] = c[1] + ... + c[i], with sum[0] = 0.
int sum[nmax + 5];

// Node of a pointer-based (persistent) segment tree.
// `st` / `dr` are the left / right children; `val` and `sum` are the two
// aggregates carried by the node (for an inner node: the totals of its children).
struct Node
{
    Node *st, *dr;
    int val, sum;

    // Empty node: no children, both aggregates zero.
    Node() : st(nullptr), dr(nullptr), val(0), sum(0) {}

    // Childless node holding the given aggregates.
    Node(int new_val, int new_sum)
        : st(nullptr), dr(nullptr), val(new_val), sum(new_sum) {}

    // Copy of `nod` (same children) with the aggregates shifted by the deltas.
    Node(Node *nod, int new_val, int new_sum)
        : st(nod->st),
          dr(nod->dr),
          val(nod->val + new_val),
          sum(nod->sum + new_sum) {}

    // Inner node over two (possibly missing) children; a missing child counts as zero.
    Node(Node *son_st, Node *son_dr)
        : st(son_st),
          dr(son_dr),
          val((son_st != nullptr ? son_st->val : 0) + (son_dr != nullptr ? son_dr->val : 0)),
          sum((son_st != nullptr ? son_st->sum : 0) + (son_dr != nullptr ? son_dr->sum : 0)) {}
};

// Roots of the persistent tree versions: rt[0] is an empty node and rt[i] is the
// version obtained from rt[i - 1] by inserting s[i].
Node *rt[nmax + 5];

// Persistent point update.
// Returns the root of a new tree version over the value range [a, b] in which
// position `poz` has `val` added to its count and `val * poz` added to its sum.
// `nod` (the previous version, possibly nullptr) is never modified; only the
// nodes on the path to `poz` are freshly allocated, everything else is shared.
Node *update(Node *nod, int poz, int val, int a, int b)
{
    if (a == b)
    {
        // Leaf: either create it or clone the old one with the deltas applied.
        return nod != nullptr ? new Node(nod, val, val * a)
                              : new Node(val, val * a);
    }

    const int mid = (a + b) >> 1;
    Node *left = (nod != nullptr) ? nod->st : nullptr;
    Node *right = (nod != nullptr) ? nod->dr : nullptr;

    // Rebuild only the child whose interval contains `poz`.
    if (poz <= mid)
    {
        left = update(left, poz, val, a, mid);
    }
    else
    {
        right = update(right, poz, val, mid + 1, b);
    }
    return new Node(left, right);
}

// Sum of the k largest values present in version `nod_dr` but not in version
// `nod_st` (i.e. in the difference of the two persistent trees), restricted to
// the value range [a, b].
// Returns {that sum, the value at which the descent ended}.
// `nod_st` may be nullptr (treated as an empty tree); `nod_dr` is dereferenced
// on every non-leaf level, so the caller must ensure at least k elements exist.
pair<int,int> query(Node *nod_st, Node *nod_dr, int k, int a, int b)
{
    if (a == b)
    {
        // All remaining picks have the same value `a`.
        return {a * k, a};
    }

    const int mid = (a + b) >> 1;

    // Children of the older version; all nullptr when that version is empty here.
    Node *old_left = (nod_st != nullptr) ? nod_st->st : nullptr;
    Node *old_right = (nod_st != nullptr) ? nod_st->dr : nullptr;
    Node *new_right = nod_dr->dr;

    // Number of elements of the difference that fall in the upper half.
    int upper_cnt = 0;
    if (new_right != nullptr)
    {
        upper_cnt += new_right->val;
    }
    if (old_right != nullptr)
    {
        upper_cnt -= old_right->val;
    }

    // Enough large elements on the right: the answer lies entirely there.
    if (upper_cnt >= k)
    {
        return query(old_right, new_right, k, mid + 1, b);
    }

    // Otherwise take the whole upper half and the rest from the lower half.
    pair<int,int> res = query(old_left, nod_dr->st, k - upper_cnt, a, mid);
    if (new_right != nullptr)
    {
        res.first += new_right->sum;
    }
    if (old_right != nullptr)
    {
        res.first -= old_right->sum;
    }
    return res;
}

// Total of c over the inclusive index range [st, dr], taken from the prefix sums.
int get_sum(int st, int dr)
{
    const int through_right = sum[dr];
    const int before_left = sum[st - 1];
    return through_right - before_left;
}

int cost(int st, int dr)
{
    int rez = -get_sum(st, dr);
    pair<int,int> val = query(rt[st - 1], rt[dr], k, 1, vmax);
    rez += val.first;
    return rez;
}

// Divide-and-conquer search for the best right endpoint of every left endpoint.
//
// For the middle left endpoint `mij` of [st, dr] it scans the right endpoints
// j in [max(min_chosen, mij + k - 1), max_chosen], stores the largest
// cost(mij, j) in r[mij] and the first j attaining it in poz[mij]. It then
// recurses on the two halves with the candidate window cut at poz[mij], which
// relies on the best right endpoint being non-decreasing in the left endpoint.
//
// st, dr                 : inclusive range of left endpoints still to process.
// min_chosen, max_chosen : inclusive window of right endpoints to consider.
void divide(int st, int dr, int min_chosen, int max_chosen)
{
    if(st > dr)
    {
        return;
    }
    int mij = (st + dr) >> 1;
    r[mij] = -oo;
    poz[mij] = 0;
    // A range must contain at least k elements, hence the lower bound mij + k - 1.
    for(int j=max(min_chosen, mij + k - 1);j<=max_chosen;j++)
    {
        // cost() descends the persistent tree, so evaluate it once per
        // candidate instead of again on every improvement.
        const int current = cost(mij, j);
        // Strict comparison keeps the smallest j among equally good candidates.
        if(current > r[mij])
        {
            poz[mij] = j;
            r[mij] = current;
        }
    }
    divide(st, mij - 1, min_chosen, poz[mij]);
    divide(mij + 1, dr, poz[mij], max_chosen);
}

// Entry point: reads n, k, the c-sequence and the s-sequence, builds the prefix
// sums and one persistent tree version per prefix of s, runs the
// divide-and-conquer search and prints the best cost over all left endpoints.
signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
    #ifdef home
    freopen("nr.in","r",stdin);
    freopen("nr.out","w",stdout);
    #endif // home

    cin >> n >> k;

    // Read c and accumulate its prefix sums on the fly.
    for (int idx = 1; idx <= n; ++idx)
    {
        cin >> c[idx];
        sum[idx] = sum[idx - 1] + c[idx];
    }

    // Version 0 is empty; version idx adds s[idx] on top of version idx - 1.
    rt[0] = new Node();
    for (int idx = 1; idx <= n; ++idx)
    {
        cin >> s[idx];
        rt[idx] = update(rt[idx - 1], s[idx], +1, 1, vmax);
    }

    // Left endpoints run over [1, n - k + 1], right endpoints over [k, n].
    const int last_left = n - k + 1;
    divide(1, last_left, k, n);

    int best = -oo;
    for (int left = 1; left <= last_left; ++left)
    {
        if (r[left] > best)
        {
            best = r[left];
        }
    }
    cout << best << '\n';
    return 0;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...