#include <bits/stdc++.h>
#define int long long
using namespace std;
struct DP {
    // A DP candidate: its total swap count together with the smallest and
    // largest number of groups among the ways that reach exactly that count.
    long long minSwaps;
    int minGroups, maxGroups;

    // Extends a candidate by another piece: every component simply adds up.
    DP operator + ( DP other ) {
        DP sum;
        sum.minSwaps = minSwaps + other.minSwaps;
        sum.minGroups = minGroups + other.minGroups;
        sum.maxGroups = maxGroups + other.maxGroups;
        return sum;
    }

    // Chooses between two candidates: the one with fewer swaps wins outright;
    // on an exact tie the group-count ranges are widened to cover both.
    DP operator * ( DP other ) {
        if ( minSwaps != other.minSwaps )
            return other.minSwaps < minSwaps ? other : *this;
        DP merged = *this;
        merged.minGroups = min( minGroups, other.minGroups );
        merged.maxGroups = max( maxGroups, other.maxGroups );
        return merged;
    }
};
// Capacity bound: every per-index table below holds indices 0..MAX_N.
const int MAX_N = 1000000;
// Sentinel swap count that any real candidate beats in DP::operator*.
const long long INF = 1000000000000000000LL;
// a[j]: max(#A seen, #B seen) when the j-th 'B' is read (filled in main).
// countLowerA[v] / sumLowerA[v]: after main's prefix pass, the number / sum of
// the a[j] values that are <= v.
int a[MAX_N + 1], countLowerA[MAX_N + 1];
// sumA[j]: prefix sum a[1] + ... + a[j].
long long sumA[MAX_N + 1], sumLowerA[MAX_N + 1];
// dp[i]: best candidate for the first i 'B's under the current penalty (computeDP).
DP dp[MAX_N + 1];
struct func {
    // Line y = a*x + b, tagged with the [valmin, valmax] group-count range of
    // the DP state that produced it.
    long long a, b;
    int valmin, valmax;

    // Evaluates the line at x.
    long long value( long long x ) {
        long long slopePart = a * x;
        return slopePart + b;
    }
};
struct CHT {
    vector<func> funcs;
    vector<long long> start;
    vector<int> stack;
    DP dppp[MAX_N + 1];
    void init() {
        for ( int i = 0; i <= MAX_N; i++ )
            dppp[i] = { INF, MAX_N + 1, 0 };
        funcs.clear();
        start.clear();
        stack.clear();
    }
    long double intersection( func f1, func f2 ) {
        return ((long double)f2.b - f1.b) / (f1.a - f2.a);
    }
    void addFunction( func f ) {
        int i = funcs.size();
        funcs.push_back( f );
        start.push_back( 0 );
        while ( !stack.empty() && intersection( funcs[i], funcs[stack.back()] ) <= start[stack.back()] )
            stack.pop_back();
        if ( !stack.empty() ) {
            long long x = intersection( funcs[i], funcs[stack.back()] );
            if ( intersection( funcs[i], funcs[stack.back()] ) == x && x <= MAX_N )
                dppp[x] = dppp[x] * DP{ funcs[stack.back()].value( x ), funcs[stack.back()].valmin, funcs[stack.back()].valmax };
            start[i] = ceil( intersection( funcs[i], funcs[stack.back()]) );
        }
        stack.push_back( i );
    }
    DP getMin( long long x ) {
        int l = -1, r = stack.size();
        while ( r - l > 1 ) {
            int mid = (l + r) / 2;
            if ( start[stack[mid]] > x )
                r = mid;
            else
                l = mid;
        }
        return DP{ funcs[stack[l]].value( x ), funcs[stack[l]].valmin, funcs[stack[l]].valmax } * dppp[x];
    }
} swapFunctions;
void computeDP( int n, long long cost ) {
    swapFunctions.init();
    swapFunctions.addFunction( { 0, 0, 0, 0 } );
    for ( int i = 1; i <= n; i++ ) {
        dp[i] = swapFunctions.getMin( i ) + DP{ i * countLowerA[i] - sumLowerA[i] + cost, 1, 1 };
        swapFunctions.addFunction( func{ -i, dp[i].minSwaps + sumA[i], dp[i].minGroups, dp[i].maxGroups } );
    }
}
// Reads n, k and a string of 2n 'A'/'B' characters, then prints the minimum
// total number of swaps when exactly k groups are used (as computed below).
signed main() {
    // Up to 2 * MAX_N characters are read one at a time; untying iostreams from
    // C stdio keeps the input loop from dominating the running time.
    ios::sync_with_stdio( false );
    cin.tie( nullptr );

    int n, k, swaps;
    cin >> n >> k;
    int countB = 0, countA = 0;
    swaps = 0;
    for ( int i = 0; i < 2 * n; i++ ) {
        char ch;
        cin >> ch;
        if ( ch == 'A' )
            countA++;
        else {
            countB++;
            // Base cost: how far this 'B' runs ahead of the 'A's seen so far.
            swaps += max( 0LL, countB - countA );
            a[countB] = max( countA, countB );
            sumA[countB] = sumA[countB - 1] + a[countB];
            countLowerA[a[countB]]++;
            sumLowerA[a[countB]] += a[countB];
        }
    }
    // Turn the histograms into prefix tables over values 0..n.
    for ( int i = 1; i <= n; i++ ) {
        countLowerA[i] += countLowerA[i - 1];
        sumLowerA[i] += sumLowerA[i - 1];
    }
    // Lagrangian relaxation ("Aliens trick"): each group is charged `cost`, and
    // we search for the smallest cost whose optimum can use at most k groups.
    // The upper bound 1e12 is presumably above any useful penalty (~n^2 for
    // n <= MAX_N) — NOTE(review): confirm against the problem limits.
    long long leftCost = -1, rightCost = 1e12;
    while ( rightCost - leftCost > 1 ) {
        long long cost = (leftCost + rightCost) / 2;
        computeDP( n, cost );
        if ( dp[n].minGroups > k )
            leftCost = cost;
        else
            rightCost = cost;
    }
    computeDP( n, rightCost );
    // Remove the penalty charged for the k groups.
    cout << swaps + dp[n].minSwaps - k * rightCost << '\n';
    return 0;
}
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... |