제출 #1280068

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1280068 | andrei_iorgulescu | 팀들 (IOI15_teams) | C++20
100 / 100
1525 ms358400 KiB
#include <bits/stdc++.h>
#include "teams.h"

using namespace std;

/// Alias used by the grader-facing signatures of init() and can().
using integer = int;

/// Static structure over (A, B) pairs answering: among the pairs whose A lies
/// in [l, r], what is the k-th smallest B?
///
/// The pairs are sorted by (A, B); position i of that order carries the value
/// v[i] = B. Node `nod` of the tree covers the value range [vl, vr] (the root,
/// node 1, covers [1, n]) and aint[nod] holds, in increasing order, every
/// position whose value lies in that range. A range of positions can therefore
/// be counted inside any node with two binary searches.
/// B values are assumed to lie in [1, n].
struct DS
{
    int n;
    vector<pair<int, int>> a;   ///(A, B) pairs sorted ascending
    vector<int> v;              ///v[i] = a[i].second
    vector<vector<int>> aint;   ///aint[nod] = ascending positions whose value is in nod's range

    /// (Re)builds the structure from N pairs. Any previous contents are
    /// discarded, so calling init more than once is safe.
    void init(int N, vector<pair<int, int>> A)
    {
        n = N;
        a = std::move(A);
        sort(a.begin(), a.end());
        v.assign(n, 0);
        for (int i = 0; i < n; i++)
            v[i] = a[i].second;
        aint.assign(4 * n + 5, vector<int>());
        aint[1].reserve(n);
        for (int i = 0; i < n; i++)
            aint[1].push_back(i);
        build(1, 1, n);
    }

    /// Splits node `nod`'s positions between its children by value; each child
    /// list stays in increasing position order because the parent list is.
    void build(int nod, int vl, int vr)
    {
        if (vl >= vr)   ///leaf; also stops at once on the empty range [1, 0] when n == 0
            return;
        int mij = (vl + vr) / 2;
        for (int it : aint[nod])
        {
            if (v[it] <= mij)
                aint[2 * nod].push_back(it);
            else
                aint[2 * nod + 1].push_back(it);
        }
        build(2 * nod, vl, mij);
        build(2 * nod + 1, mij + 1, vr);
    }

    /// Number of entries of the ascending list `lst` lying in [lo, hi].
    /// Can be negative when lo > hi + 1, matching the callers' arithmetic.
    static int count_between(const vector<int>& lst, int lo, int hi)
    {
        return static_cast<int>(upper_bound(lst.begin(), lst.end(), hi) -
                                lower_bound(lst.begin(), lst.end(), lo));
    }

    /// Returns the k-th smallest value (1-based) among positions [stt, drr],
    /// descending from node `nod`, which covers values [vl, vr].
    int query(int nod, int vl, int vr, int stt, int drr, int k)
    {
        if (vl == vr)
            return vl;
        int mij = (vl + vr) / 2;
        int inLeft = count_between(aint[2 * nod], stt, drr);
        if (inLeft >= k)
            return query(2 * nod, vl, mij, stt, drr, k);
        return query(2 * nod + 1, mij + 1, vr, stt, drr, k - max(0, inLeft));
    }

    /// Returns the inclusive position range {first, last} of the pairs whose A
    /// lies in [l, r]; the range is empty (first > last) when there are none.
    pair<int, int> getgood(int l, int r)
    {
        auto from = lower_bound(a.begin(), a.begin() + n, l,
                                [](const pair<int, int>& p, int x) { return p.first < x; });
        auto to = upper_bound(a.begin(), a.begin() + n, r,
                              [](int x, const pair<int, int>& p) { return x < p.first; });
        return {static_cast<int>(from - a.begin()), static_cast<int>(to - a.begin()) - 1};
    }

    /// k-th smallest B (1-based) among the pairs with A in [l, r].
    /// Returns n + 1 when there are at most k such pairs.
    int query_kth(int l, int r, int k)
    {
        pair<int, int> range = getgood(l, r);
        if (range.second - range.first + 1 <= k)
            return n + 1;
        return query(1, 1, n, range.first, range.second, k);
    }
};

/// Persistent segment tree over positions A in [1, n]. Version roots[r]
/// contains every (A, B) pair with B <= r, stored at position A, so
/// cate_incl(l, r) counts the pairs with A >= l and B <= r.
struct DS2
{
    struct node
    {
        int fs, fd;///children, as indices into `nodes` (-1 for leaves)
        int val;///number of stored pairs whose A falls in this node's range
    };

    int n;
    vector<pair<int, int>> a;
    vector<vector<int>> f;      ///f[b] = A values of the pairs with B == b
    vector<node> nodes;         ///pool holding the nodes of every version
    vector<int> roots;          ///roots[r] = root of the version with all pairs B <= r
    int rt;                     ///root of the most recent version

    /// Builds the all-zero tree for [l, r] below the already allocated node `nod`.
    void build(int nod, int l, int r)
    {
        if (l >= r)   ///leaf; also stops at once on the empty range [1, 0] when n == 0
        {
            nodes[nod].val = 0;
            nodes[nod].fs = nodes[nod].fd = -1;
            return;
        }
        int mij = (l + r) / 2;
        node aux;
        aux.val = 0;
        aux.fs = aux.fd = -1;
        nodes.push_back(aux);
        nodes[nod].fs = static_cast<int>(nodes.size()) - 1;
        nodes.push_back(aux);
        nodes[nod].fd = static_cast<int>(nodes.size()) - 1;
        build(nodes[nod].fs, l, mij);
        build(nodes[nod].fd, mij + 1, r);
    }

    /// Adds one point at `pos` below the freshly copied node `nod`. The child on
    /// the path is copied before it is modified, so older versions stay intact.
    void update(int nod, int l, int r, int pos)
    {
        nodes[nod].val++;
        if (l == r)
            return;
        int mij = (l + r) / 2;
        bool goLeft = pos <= mij;
        node copy = nodes[goLeft ? nodes[nod].fs : nodes[nod].fd];   ///copy before push_back may reallocate
        nodes.push_back(copy);
        int child = static_cast<int>(nodes.size()) - 1;
        if (goLeft)
        {
            nodes[nod].fs = child;
            update(child, l, mij, pos);
        }
        else
        {
            nodes[nod].fd = child;
            update(child, mij + 1, r, pos);
        }
    }

    /// (Re)builds every version from N pairs; B values must lie in [0, N].
    /// Previous contents are discarded, so calling init more than once is safe.
    void init(int N, vector<pair<int, int>> A)
    {
        n = N;
        a = std::move(A);
        f.assign(n + 1, vector<int>());
        for (const auto& it : a)
            f[it.second].push_back(it.first);
        roots.assign(n + 1, 0);
        nodes.clear();

        /// The initial tree has at most 2n - 1 nodes. Every insertion allocates
        /// `depth` nodes (a root copy plus one copy per lower level). Reserving
        /// that bound up front avoids repeatedly reallocating (and transiently
        /// doubling) a pool of millions of nodes.
        int depth = 1;   ///nodes on a root-to-leaf path: ceil(log2(n)) + 1
        while ((1LL << (depth - 1)) < n)
            depth++;
        nodes.reserve(2 * static_cast<size_t>(max(n, 0)) + a.size() * static_cast<size_t>(depth));

        node empty;
        empty.val = 0;
        empty.fs = -1;
        empty.fd = -1;
        nodes.push_back(empty);
        build(0, 1, n);
        rt = 0;
        for (int i = 1; i <= n; i++)
        {
            for (int pos : f[i])
            {
                node rootCopy = nodes[rt];
                nodes.push_back(rootCopy);
                rt = static_cast<int>(nodes.size()) - 1;
                update(rt, 1, n, pos);
            }
            roots[i] = rt;
        }
    }

    /// Sum of the counts at positions [st, dr] in the version rooted at `nod`.
    int query(int nod, int l, int r, int st, int dr)
    {
        if (r < st or dr < l)
            return 0;
        if (st <= l and r <= dr)
            return nodes[nod].val;
        int mij = (l + r) / 2;
        return query(nodes[nod].fs, l, mij, st, dr) + query(nodes[nod].fd, mij + 1, r, st, dr);
    }

    /// Number of pairs with A >= l and B <= r, i.e. whose [A, B] lies inside [l, r].
    int cate_incl(int l, int r)
    {
        return query(roots[r], 1, n, l, n);
    }
};

/// k-th-smallest-B queries over the (A, B) pairs restricted to an A range.
DS wavelet;
/// Persistent tree counting the (A, B) pairs with A >= l and B <= r.
DS2 pst;
/// Number of (A, B) pairs, set by init().
int n;

/// Grader entry point: records the pair count and builds both helper
/// structures from the pairs (A[i], B[i]) for i = 0..N-1.
void init(integer N, integer A[], integer B[])
{
    n = N;
    vector<pair<int, int>> students;
    for (int idx = 0; idx < N; ++idx)
        students.emplace_back(A[idx], B[idx]);
    wavelet.init(n, students);
    pst.init(n, students);
}

/// Grader entry point: returns 1 if the requested team sizes K[0..M-1] pass
/// the feasibility check below, 0 otherwise.
///
/// a[] lists the distinct sizes s (ascending) with their total demand
/// s * (#teams of size s), preceded by a (0, 0) sentinel. For a chain of chosen
/// sizes, the code adds their total demand to the number of (A, B) pairs lying
/// strictly between consecutive chosen sizes (the sentinel 0 acting as the
/// first one) and strictly after the last one. It rejects when the maximum of
/// that quantity over all chains exceeds n.
/// NOTE(review): this reads as a Hall's-theorem style condition — confirm.
///
/// dp[i] = best value for chains ending at a[i], not yet counting the pairs
/// that lie after a[i].first.
integer can(integer M, integer K[])
{
    long long m = M, sm = 0;
    map<int, long long> mp;
    for (int i = 0; i < m; i++)
    {
        // mp[s] accumulates the total demand of all teams of size s; sm is the overall demand.
        mp[K[i]] += K[i], sm += K[i];
    }
    // More demand than pairs is always infeasible. For non-negative K this also
    // keeps every size <= n, which the index-based queries below rely on.
    if (sm > n)
        return 0;
	vector<pair<int, int>> a;
	for (auto it : mp)
        a.push_back(it);
    a.push_back({0, 0});
    sort(a.begin(), a.end());///the map already yields ascending keys, so (assuming every K[i] > 0) this only moves the {0, 0} sentinel to the front
    // s: candidate predecessors j that may still be optimal. Consecutive members
    // are kept so that the later one currently scores at least as well, which
    // makes the largest member the best predecessor for the current size.
    set<int> s;
    vector<int> dp(a.size());
    dp[0] = 0;
    s.insert(0);
    // bang: events (t, j) for each j in s with a successor k. The score gap
    // between j and k grows by the number of pairs with A in (a[j].first,
    // a[k].first] and B < current size. t is the (dp[k] - dp[j])-th smallest B
    // among those pairs (n + 1, i.e. never, when there are at most that many).
    // Once the current size exceeds t, j is at least as good as k.
    set<pair<int, int>> bang;
    int ans = n;
    for (int i = 1; i < a.size(); i++)
    {
        // Fire every event whose threshold is below the current size: its
        // successor xx is dropped from s. The now-stale event of xx (against yy)
        // is erased by recomputing its key, and a fresh event zz-vs-yy is added.
        while (!bang.empty())
        {
            pair<int, int> it = *bang.begin();
            if (it.first >= a[i].first)
                break;
            bang.erase(it);
            int xx = *s.upper_bound(it.second);
            int zz = it.second;
            s.erase(xx);
            if (s.upper_bound(it.second) != s.end())
            {
                int yy = *s.upper_bound(it.second);
                bang.erase({wavelet.query_kth(a[xx].first + 1, a[yy].first, dp[yy] - dp[xx]), xx});
                bang.insert({wavelet.query_kth(a[zz].first + 1, a[yy].first, dp[yy] - dp[zz]), zz});
            }
        }
        // Best predecessor, then: demand of a[i] + dp[cn] + pairs lying strictly
        // between a[cn].first and a[i].first.
        int cn = *s.rbegin();
        dp[i] = a[i].second + dp[cn] + pst.cate_incl(a[cn].first + 1, a[i].first - 1);
        bang.insert({wavelet.query_kth(a[cn].first + 1, a[i].first, dp[i] - dp[cn]), cn});
        // Close the chain at a[i]: add the pairs lying entirely after a[i].first.
        ans = max(ans, dp[i] + pst.cate_incl(a[i].first + 1, n));
        s.insert(i);
    }
    if (ans > n)
        return 0;
    else
        return 1;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...