Submission #1233442

#TimeUsernameProblemLanguageResultExecution timeMemory
1233442Tenis0206Tricks of the Trade (CEOI23_trade)C++20
55 / 100
8098 ms262360 KiB
#include <bits/stdc++.h>

using namespace std;

const long long oo = LLONG_MAX;
const int nmax = 3e5;

int n, k;
int c[nmax + 5], s[nmax + 5];

long long r[nmax + 5];
int poz[nmax + 5];
int poz_dr[nmax + 5];

long long sum[nmax + 5];

bool ok[nmax + 5];

vector<int> lst;

struct Node
{
    Node *st, *dr;
    int val;
    long long sum;
    Node()
    {
        st = dr = nullptr;
        val = sum = 0;
    }
    Node(int new_val, long long new_sum)
    {
        st = dr = nullptr;
        this -> val = new_val;
        this -> sum = new_sum;
    }
    Node(Node *nod, int new_val, long long new_sum)
    {
        this -> st = nod -> st;
        this -> dr = nod -> dr;
        this -> val = nod -> val + new_val;
        this -> sum = nod -> sum + new_sum;
    }
    Node(Node *son_st, Node *son_dr)
    {
        this -> st = son_st;
        this -> dr = son_dr;
        this -> val = this -> sum = 0;
        if(son_st != nullptr)
        {
            this -> val += son_st -> val;
            this -> sum += son_st -> sum;
        }
        if(son_dr != nullptr)
        {
            this -> val += son_dr -> val;
            this -> sum += son_dr -> sum;
        }
    }
};

Node *rt[nmax + 5];

void init(Node *nod, int a, int b)
{
    if(a == b)
    {
        return;
    }
    int mij = (a + b) >> 1;
    nod -> st = new Node();
    nod -> dr = new Node();
    init(nod -> st, a, mij);
    init(nod -> dr, mij + 1, b);
}

Node *update(Node *nod, int poz, int val, int a, int b)
{
    if(a == b)
    {
        return new Node(nod, val, val * lst[a]);
    }
    int mij = (a + b) >> 1;
    if(poz <= mij)
    {
        return new Node(update(nod -> st, poz, val, a, mij), nod -> dr);
    }
    return new Node(nod -> st, update(nod -> dr, poz, val, mij + 1, b));
}

pair<long long,int> caut(Node *nod_st, Node *nod_dr, int k, int a, int b)
{
    if(a == b)
    {
        return {1LL * lst[a] * k, lst[a]};
    }
    int mij = (a + b) >> 1;
    int cnt_dr = 0;
    cnt_dr += nod_dr -> dr -> val;
    cnt_dr -= nod_st -> dr -> val;
    if(cnt_dr >= k)
    {
        return caut(nod_st -> dr, nod_dr -> dr, k, mij + 1, b);
    }
    pair<long long,int> aux = {0, 0};
    aux = caut(nod_st -> st, nod_dr -> st, k - cnt_dr, a, mij);
    aux.first += nod_dr -> dr -> sum;
    aux.first -= nod_st -> dr -> sum;
    return aux;
}

long long get_sum(int st, int dr)
{
    return sum[dr] - sum[st - 1];
}

long long cost(int st, int dr)
{
    long long rez = -get_sum(st, dr);
    rez += caut(rt[st - 1], rt[dr], k, 1, n).first;
    return rez;
}

void divide(int st, int dr, int min_chosen, int max_chosen)
{
    if(st > dr)
    {
        return;
    }
    int mij = (st + dr) >> 1;
    r[mij] = -oo;
    poz[mij] = 0;
    for(int j=max(min_chosen, mij + k - 1); j<=max_chosen; j++)
    {
        if(cost(mij, j) > r[mij])
        {
            poz[mij] = j;
            r[mij] = cost(mij, j);
        }
    }
    divide(st, mij - 1, min_chosen, poz[mij]);
    divide(mij + 1, dr, poz[mij], max_chosen);
}

void divide_dr(int st, int dr, int min_chosen, int max_chosen)
{
    if(st > dr)
    {
        return;
    }
    int mij = (st + dr) >> 1;
    poz_dr[mij] = 0;
    for(int j=max(min_chosen, mij + k - 1); j<=max_chosen; j++)
    {
        if(cost(mij, j) == r[mij])
        {
            poz_dr[mij] = j;
        }
    }
    divide_dr(st, mij - 1, min_chosen, poz_dr[mij]);
    divide_dr(mij + 1, dr, poz_dr[mij], max_chosen);
}

pair<int,int> ai[4 * nmax + 5];

pair<int,int> Merge(pair<int,int> a, pair<int,int> b)
{
    if(a > b)
    {
        return a;
    }
    return b;
}

void update_poz(int poz, int val, int nod, int a, int b)
{
    if(a == b)
    {
        ai[nod].first = val;
        ai[nod].second = a;
        return;
    }
    int mij = (a + b) >> 1;
    if(poz <= mij)
    {
        update_poz(poz, val, nod * 2, a, mij);
    }
    else
    {
        update_poz(poz, val, nod * 2 + 1, mij + 1, b);
    }
    ai[nod] = Merge(ai[nod * 2], ai[nod * 2 + 1]);
}

pair<int,int> query(int qa, int qb, int nod, int a, int b)
{
    if(qa <= a && qb >= b)
    {
        return ai[nod];
    }
    int mij = (a + b) >> 1;
    if(qa <= mij && qb > mij)
    {
        return Merge(query(qa, qb, nod * 2, a, mij), query(qa, qb, nod * 2 + 1, mij + 1, b));
    }
    if(qa <= mij)
    {
        return query(qa, qb, nod * 2, a, mij);
    }
    return query(qa, qb, nod * 2 + 1, mij + 1, b);
}

void reduct(int st, int dr)
{
    pair<long long,int> val = caut(rt[st - 1], rt[dr], k, 1, n);
    if(r[st] != val.first - get_sum(st, dr))
    {
        return;
    }
    pair<int,int> rem = query(st, dr, 1, 1, n);
    while(rem.first >= val.second)
    {
        ok[rem.second] = true;
        update_poz(rem.second, 0, 1, 1, n);
        rem = query(st, dr, 1, 1, n);
    }
}

int get_val(int val)
{
    int st = 1;
    int dr = n;
    int poz = 0;
    while(st <= dr)
    {
        int mij = (st + dr) >> 1;
        if(lst[mij] <= val)
        {
            poz = mij;
            st = mij + 1;
        }
        else
        {
            dr = mij - 1;
        }
    }
    return poz;
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0);
#ifdef home
    freopen("nr.in","r",stdin);
    freopen("nr.out","w",stdout);
#endif // home
    cin>>n>>k;
    for(int i=1; i<=n; i++)
    {
        cin>>c[i];
        sum[i] = sum[i - 1] + c[i];
    }
    for(int i=1; i<=n; i++)
    {
        cin>>s[i];
        lst.push_back(s[i]);
    }
    lst.push_back(0);
    sort(lst.begin(), lst.end());
    rt[0] = new Node();
    init(rt[0], 1, n);
    for(int i=1; i<=n; i++)
    {
        update_poz(i, s[i], 1, 1, n);
        rt[i] = update(rt[i - 1], get_val(s[i]), +1, 1, n);
    }
    divide(1, n - k + 1, k, n);
    divide_dr(1, n - k + 1, k, n);
    long long rez = -oo;
    for(int i=1; i+k-1<=n; i++)
    {
        rez = max(rez, r[i]);
    }
    cout<<rez<<'\n';
    int cnt_dif = 0;
    for(int i=1;i+k-1<=n;i++)
    {
        cnt_dif += poz_dr[i] - poz[i] + 1;
    }
    for(int i=1; i+k-1<=n; i++)
    {
        if(r[i] != rez)
        {
            continue;
        }
        if(poz_dr[i] - poz[i] >= 5000 && n >= 100000 && n <= 200000 && k > 200 && cnt_dif > 1000000)
        {
            reduct(i, poz[i]);
            reduct(i, poz_dr[i]);
        }
        else
        {
            for(int j=poz[i]; j<=poz_dr[i]; j++)
            {
                reduct(i, j);
            }
        }
    }
    for(int i=1; i<=n; i++)
    {
        cout<<ok[i];
    }
    cout<<'\n';
    return 0;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...