// Submission #1347466
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 1347466 | - | MMihalev | Chorus (JOI23_chorus) | C++20 | 61 / 100 | 4262 ms | 9316 KiB
#include<iostream>
#include<vector>
#include<algorithm>
#include<set>
#include<unordered_map>
#include<cmath>
using namespace std;

// Capacity of every per-index array below (indices 0..n are used).
const int MAX_N=1e5+6;
// "Unreached" marker for DP values; large, but leaves headroom below LLONG_MAX.
const long long INF=(1LL<<62);

// p[i] (1-indexed): number of non-'A' characters that precede the i-th 'A'
// in s; filled in by main().
int p[MAX_N];
// pp[i] = p[1] + ... + p[i]. Kept in 64 bits: a sum of up to MAX_N terms,
// each of which can be as large as the string length, does not fit in 32 bits.
long long pp[MAX_N];
// n and k as read from the first input line (the string has 2*n characters;
// k is the bound the group count is compared against in main()).
int n,k;
// The input string.
string s;

// A straight line y = m*x + c with extended-precision coefficients.
struct line
{
    long double m,c;

    // Value of the line at x, converted to an integer (fraction is truncated
    // towards zero by the floating-to-integer conversion).
    long long calc(long long x)
    {
        const long double value=m*x+c;
        return static_cast<long long>(value);
    }

    // x-coordinate at which this line and l take the same value.
    // The two slopes must differ, otherwise the division is by zero.
    long double intersect(line l)
    {
        const long double interceptGap=c-l.c;
        const long double slopeGap=l.m-m;
        return interceptGap/slopeGap;
    }
};

// Stack of lines maintained by add(); groups() evaluates them at the current index.
vector<line>dp;
// Parallel to dp. bor[t] is the floor of the x-coordinate where line t meets
// the line pushed right after it (1000000000 for the newest line);
// mincnt[t] is the count that was passed to add() together with line t.
vector<int>bor,mincnt;

// Pushes line l (with its associated count cnt) onto the dp/bor/mincnt stacks.
//
// Lines at the top of the stack are popped while the new line meets the top
// line no later than it meets the line below it. The border of the line that
// ends up directly under the new one is then set to the floor of their
// intersection, and the new line gets the 1000000000 sentinel border.
void add(line l,int cnt)
{
    while(dp.size()>=2 && l.intersect(dp.back())<=l.intersect(dp[dp.size()-2]))
    {
        dp.pop_back();
        bor.pop_back();
        mincnt.pop_back();
    }

    if(!dp.empty())
    {
        // Intercepts can contain penalty-sized terms, so the intersection may
        // lie far outside the range of int. Converting such a value to int is
        // undefined behaviour, so clamp it to the sentinel range first; every
        // query point is far inside that range, so the clamp loses nothing.
        const long double LIMIT=1000000000.0L;
        long double cut=floor(l.intersect(dp.back()));
        if(cut>LIMIT)cut=LIMIT;
        else if(!(cut>=-LIMIT))cut=-LIMIT;
        bor.back()=static_cast<int>(cut);
    }

    dp.push_back(l);
    bor.push_back(1000000000);
    mincnt.push_back(cnt);
}


// (dpp[j], ints[j]) pairs of the states groups() currently offers as
// constant candidates; its smallest element is compared against dpp[i].
multiset<pair<long long,int>>active;
// here[t] holds pairs (j, t): states j that groups() removes from `active`
// and turns into a line once its main loop reaches index t.
vector<pair<int,int>>here[MAX_N];
// dpp[i]: best value found for index i during the current groups() call
// (INF while unreached; the penalty is added for every i > 0).
long long dpp[MAX_N];
// ints[i]: count attached to dpp[i] (source count + 1, smaller count on ties).
int ints[MAX_N];
// ints[] value of the last index processed by the most recent groups() call.
int lastgroups;
// dpp[] value of the last index processed by the most recent groups() call.
long long lastdp;

// Scratch line that groups() fills in before each add() call.
line l;
int groups(long long penalty)
{
    active.clear();
    for(int i=0;i<=n;i++)here[i].clear();
    dp.clear();bor.clear();mincnt.clear();
    
    for(int i=1;i<=n;i++){dpp[i]=INF;}
    dpp[0]=0;
    ints[0]=0;

    for(int i=0;i<=n;i++)
    {
        for(auto&[j,id]:here[i])
        {
            active.erase(active.find({dpp[j],ints[j]}));
            
            l.m=-j;
            l.c=-pp[id-1]+id*j-j+dpp[j];
            add(l,ints[j]);
        }

        if(active.size())
        {
            long long top;
            int curints;
            tie(top,curints)=*active.begin();
            if(top<dpp[i])
            {
                dpp[i]=top;
                ints[i]=curints+1;
            }
            else if(top==dpp[i])
            {
                ints[i]=min(ints[i],curints+1);
            }
        }

        int x=i;
        int l=0,r=dp.size()-1;
        int mid;
        long long top;
        int curints;
        int pos;
        while(l<=r)
        {
            int mid=(l+r)/2;
            if(x<=bor[mid])
            {
                pos=mid;
                top=dp[mid].calc(x)+pp[i];
                curints=mincnt[mid];

                if(top<dpp[i])
                {
                    dpp[i]=top;
                    ints[i]=curints+1;
                }
                else if(top==dpp[i])
                {
                    ints[i]=min(ints[i],curints+1);
                }

                r=mid-1;
            }
            else l=mid+1;
        }

        for(int mid=max(0,pos-50);mid<min(pos+1000,(int)dp.size());mid++)
        {
            top=dp[mid].calc(x)+pp[i];
            curints=mincnt[mid];

            if(top<dpp[i])
            {
                dpp[i]=top;
                ints[i]=curints+1;
            }
            else if(top==dpp[i])
            {
                ints[i]=min(ints[i],curints+1);
            }
        }

        for(int mid=max(0,(int)dp.size()-100);mid<dp.size();mid++)
        {
            top=dp[mid].calc(x)+pp[i];
            curints=mincnt[mid];

            if(top<dpp[i])
            {
                dpp[i]=top;
                ints[i]=curints+1;
            }
            else if(top==dpp[i])
            {
                ints[i]=min(ints[i],curints+1);
            }
        }
    

        l=i+1;r=n;
        int id=-1;
        while(l<=r)
        {
            int mid=(l+r)/2;
            if(p[mid]>i)
            {
                id=mid;
                r=mid-1;
            }
            else l=mid+1;
        }

        if(id!=-1 && dpp[i]!=INF)
        {
            here[id].push_back({i,id});
        }

        if(dpp[i]!=INF){if(i){dpp[i]+=penalty;}active.insert({dpp[i],ints[i]});}

        
        lastdp=dpp[i];
        lastgroups=ints[i];
    }

    return lastgroups;
}

// Reads n, k and the string, builds p[] / pp[], then binary-searches the
// per-group penalty: the smallest penalty for which groups() reports at most
// k is kept, and the printed answer is that run's lastdp minus k * penalty
// (-1 if no probed penalty qualified).
int main ()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    cin>>n>>k;
    cin>>s;

    // p[idx] = number of non-'A' characters seen before the idx-th 'A'.
    // The extra size check keeps a short/truncated string from being read
    // out of bounds.
    int seenA=0;
    int seenOther=0;
    for(int i=0;i<2*n && i<(int)s.size();i++)
    {
        if(s[i]=='A')
        {
            seenA++;
            p[seenA]=seenOther;
        }
        else seenOther++;
    }

    for(int i=1;i<=n;i++)
    {
        pp[i]=pp[i-1]+p[i];
    }

    long long lo=0,hi=(1LL<<60);
    bool found=false;
    long long bestPenalty=0;
    long long bestValue=0;
    while(lo<=hi)
    {
        const long long mid=(lo+hi)/2;
        if(groups(mid)<=k)
        {
            // Only remember the probe here. Forming k*mid for the huge
            // penalties tried early in the search overflows long long; the
            // subtraction is done once, for the final (smallest) penalty.
            found=true;
            bestPenalty=mid;
            bestValue=lastdp;
            hi=mid-1;
        }
        else lo=mid+1;
    }

    long long ans=-1;
    if(found)ans=bestValue-(long long)k*bestPenalty;

    cout<<ans<<"\n";

    return 0;
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...