Submission #896068

# | Time | Username | Problem | Language | Result | Execution time | Memory
896068 | | cadmiumsky | Chorus (JOI23_chorus) | C++17 | 100 / 100 | 1179 ms | 75892 KiB
#include <bits/stdc++.h>
#define all(x) (x).begin(),(x).end()
using namespace std;

using ll = long long;
using ld = long double;

#define int ll
#define sz(x) ((int)(x).size())

using pii = pair<int,int>;
using tii = tuple<int,int,int>;

// Capacity of every global table below. Note that `int` is a macro for `ll`
// (see the `#define int ll` above), so this constant is a 64-bit integer.
const int nmax = 1e6 + 5;
// Not referenced anywhere in the visible source.
const ll inf = 1e9 + 5;

// 1-indexed. main() first stores in B[j] the number of 'A' characters read
// before the j-th non-'A' character, then raises entries so that
// B[i] >= max(B[i - 1], i). B[0] stays 0.
int B[nmax];


// After main()'s prefix pass: cnt_under[v] is the number of indices i in 1..n
// with B[i] <= v, and sum_under[v] is the sum of those B[i].
ll sum_under[nmax], cnt_under[nmax];
// sum_pref[i] = B[1] + ... + B[i], with sum_pref[0] = 0.
ll sum_pref[nmax];

struct IWCnt {
  ll val;
  ll cnt;
  IWCnt(): val(0), cnt(0) {;}
  IWCnt(ll a, ll b): val(a), cnt(b) {;}
  IWCnt operator + (const IWCnt& x) const { return IWCnt(val + x.val, cnt + x.cnt); }
  IWCnt operator + (const ll& x) const { return IWCnt(val + x, cnt); }
  IWCnt operator * (const ll& x) const { return IWCnt(val * x, cnt); }
};

namespace CHT {
  struct Line {
    ll m;
    IWCnt b;
    Line(ll a, IWCnt c): m(a), b(c) {;}
    ll operator()(const ll& x) const { return m * x + b.val; }
  };
  
  bool bad(Line second, Line last, Line nv) {
    return (nv.b.val - last.b.val) * (second.m - nv.m) < (nv.b.val - second.b.val) * (last.m - nv.m); 
  }
  
  vector<Line> st;
  int ptr;
  
  void push(Line a) {
    while(sz(st) > 1 && bad(rbegin(st)[1], rbegin(st)[0], a)) st.pop_back();
    st.emplace_back(a);
  }
  
  pii query(int P) {
    ptr = min(ptr, sz(st) - 1);
    while(ptr + 1 < sz(st) && (st[ptr](P) > st[ptr + 1](P) || (st[ptr](P) == st[ptr + 1](P) && st[ptr + 1].b.cnt > st[ptr].b.cnt))) // ca gen, aia mari sunt teoretic primii intalniti, si ar trebui sa ii aflu pe ei pt a ma asigura ca-s pe panta corecta
      ptr++;
      
    return pii{st[ptr](P), st[ptr].b.cnt};
  }
  
  void clear() {
    st.clear();
    ptr = 0;
  }
}

// DP table rewritten by every call to check(): dp[n + 1] is the zero base
// state and dp[i] for i = n..1 holds {penalised value, count}.
IWCnt dp[nmax];

// Both are read from standard input in main(). main() reads 2 * n characters;
// k is the count that check() compares dp[1].cnt against.
int n, k;
// Runs the DP once with a penalty of `lambda` added per step and fills
// dp[1..n] (dp[n + 1] is the zero base state).
//
// For i from n down to 1, every index j in [B[i], n] not yet inserted becomes
// the line  y = j * x + dp[j + 1].val - sum_under[j] + cnt_under[j] * j + j
// (carrying dp[j + 1].cnt). dp[i] is then the CHT answer at x = -i plus
// sum_pref[i - 1] + lambda, with the count increased by one.
//
// Returns true when the count recorded in dp[1] is at least k.
bool check(ll lambda) {
  CHT::clear();
  dp[n + 1] = IWCnt();

  // Largest index j whose line has not been pushed yet.
  int pending = n;

  for(int i = n; i >= 1; --i) {
    for(; pending >= B[i]; --pending) {
      const ll shift = cnt_under[pending] * pending - sum_under[pending] + pending;
      CHT::push(CHT::Line(pending, dp[pending + 1] + shift));
    }

    const pii best = CHT::query(-i);
    dp[i].val = best.first + sum_pref[i - 1] + lambda;
    dp[i].cnt = best.second + 1;
  }

  return !(dp[1].cnt < k);
}

signed main() {
  cin.tie(0) -> sync_with_stdio(0);
  cin >> n >> k;
  
  int cnt[2] = {0, 0};
  char ch;
  for(int i = 0; i < 2 * n; i++) {
    cin >> ch;
    if(ch == 'A')
      cnt[0]++;
    else
      B[++cnt[1]] = cnt[0];
      //cerr << cnt[0] << '\n';
  }
  
  ll fixing = 0;
  for(int i = 1; i <= n; i++) {
    int target = max(B[i - 1], i);
    fixing += max(0LL, target - B[i]);
    B[i] += max(0LL, target - B[i]);
  }
  
  for(int i = 1; i <= n; i++)
    cnt_under[B[i]]++,
    sum_under[B[i]] += B[i],
    sum_pref[i] = sum_pref[i - 1] + B[i];
  
  for(int i = 1; i <= n; i++)
    cnt_under[i] += cnt_under[i - 1],
    sum_under[i] += sum_under[i - 1];
  
  //for(int i = 1; i <= n; i++)
    //cout << B[i] << ' ';
  //cout << '\n';
  
  ll lambda = 0;
  for(int i = 40; i >= 0; i--) {
    if(check(lambda + (1LL << i)))
      lambda += (1LL << i);
  }
  
  check(lambda);
  
  
  cout << dp[1].val - lambda * k + fixing << '\n';
}

# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...