// This submission was migrated from a previous version of oj.uz, which used a different grading machine; it may receive a different result if resubmitted.
#include<bits/stdc++.h>
using namespace std;
// Random generator seeded from the clock; not referenced anywhere in the visible code.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
typedef long long ll;
// From here on every `int` is really `long long`; this is why main is declared `signed`.
#define int ll
typedef unsigned long long ull;
typedef long double ld;
// Because of the macro above, pii is pair<long long, long long> (identical to pll).
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
// Competitive-programming shorthands; note sz(x) casts to long long because of `#define int ll`.
#define pb push_back
#define all(x) x.begin(), x.end()
#define sz(x) (int)x.size()
// Fast I/O: disables C-stdio synchronization and unties cin/cout (used once at the start of main).
#define mispertion ios_base::sync_with_stdio(0), cin.tie(0), cout.tie(0)
#define F first
#define S second
#define getlast(s) (*s.rbegin())
#define debg cout << "OK\n"
const ld PI = 3.1415926535;
// The global arrays are sized 2 * N, so N bounds n (presumably the problem's n limit — confirm).
const int N = 1e6+5;
const int M = 50 + 1;
// Modulus used by mult / sum / binpow.
const int mod = 1e9+7;
// "Infinity" sentinel for DP costs; an empty hull query returns {infi, infi}.
const int infi = 1e15;
const ll infl = LLONG_MAX;
const int P = 31;
// Returns a * b reduced modulo `mod`.
// The operands themselves are not reduced first, so their product must fit
// in a long long (`int` is `long long` in this file).
int mult(int a, int b) {
    const long long product = 1LL * a * b;
    return product % mod;
}
// Adds two residues and folds the result back into range with at most one
// correction: +mod if the sum went negative, -mod if it reached mod.
int sum(int a, int b) {
    const int total = a + b;
    if (total < 0) {
        return total + mod;
    }
    return total >= mod ? total - mod : total;
}
// Computes a^n modulo `mod` by iterative square-and-multiply.
// The base is first reduced into [0, mod). The previous recursive version
// multiplied by the raw `a`, so `binpow(a, n-1) * a` could overflow long long
// for large |a|, and it returned negative "residues" for negative `a`.
// Returns 1 for n <= 0, as the previous version did.
ll binpow(ll a, ll n) {
    a %= mod;
    if (a < 0) {
        a += mod;  // C++ '%' keeps the dividend's sign
    }
    ll result = 1;
    while (n > 0) {
        if (n & 1) {
            result = result * a % mod;
        }
        a = a * a % mod;  // both factors < mod, so no overflow
        n >>= 1;
    }
    return result;
}
// Global state shared by normalize(), getans() and solve(); arrays are 1-indexed.
//   n     : read in solve(); s holds 2n letters.
//   k     : read in solve(); the bound that getans(C).S is compared against.
//   ans   : adjacent swaps accumulated by normalize().
//   ps[i] : position of the i-th 'A' of the normalized s (ps[0] stays 0).
//   R[i]  : position of the 'B' that solve() pairs with the i-th 'A', i.e. the
//           i-th 'B' from the left.
//           ps/R ranks assume s holds exactly n of each letter — TODO confirm.
//   prc[i]: number of 'B' in s[1..i].
//   prp[i]: sum of the positions of the 'B's in s[1..i].
int n, k, ans, ps[2 * N], R[2 * N], prp[2 * N], prc[2 * N];
// Input string; solve() prepends '#' so the letters are 1-indexed.
string s;
// Rewrites the global `s` (1-indexed, s[0] == '#', 2n letters) so that every
// prefix contains at least as many 'A' as 'B'. The number of adjacent swaps
// performed is added to the global `ans`.
//
// Greedy: scan left to right. Whenever a 'B' would drive the balance
// (#A - #B) negative, the next not-yet-used 'A' is pulled forward to the
// current spot. Pulls only ever move 'A's that lie to the right of the scan.
//
// Assumes s contains at least as many 'A' as are needed; with exactly n of
// each letter this always holds — TODO confirm against the input spec.
void normalize(){
    // psa[j] = original position of the j-th 'A' (0-based j).
    // (The former `psb` list of 'B' positions was never read and has been dropped.)
    vector<int> psa;
    psa.reserve(n);
    for(int i = 1; i <= 2 * n; i++){
        if(s[i] == 'A')
            psa.pb(i);
    }

    string ns = "#";
    // balance: #A - #B in ns.
    // skp:     how many upcoming 'A's of s were already pulled forward and must be skipped.
    // cur:     index into psa of the next 'A' not yet placed into ns.
    int balance = 0, skp = 0, cur = 0;
    for(int i = 1; i <= 2 * n; i++){
        if(s[i] == 'A' && skp){
            skp--;
            continue;
        }
        if(s[i] == 'B' && balance == 0){
            ns += "A";
            // Earlier pulls only moved 'A's left of psa[cur], so that 'A' still
            // sits at psa[cur]. It lands at position sz(ns) - 1; the swap count
            // is the distance between the two.
            ans += (psa[cur] - sz(ns) + 1);
            balance++;
            skp++;
            i--;  // re-examine this same 'B' now that the balance is positive
            cur++;
            continue;
        }
        if(s[i] == 'B'){
            ns += "B";
            balance--;
        }else{
            ns += "A";
            cur++;
            balance++;
        }
    }
    s = ns;
}
// A line y = k * x + b stored in long double.
struct Line {
    ld k;  // slope
    ld b;  // y-intercept
};
// x-coordinate t at which x.k * t + x.b == y.k * t + y.b.
// The slopes must differ; otherwise the division is by zero.
ld intersect(Line x, Line y){
    const ld rise = y.b - x.b;
    const ld run = x.k - y.k;
    return rise / run;
}
// Convex hull trick with a monotone query pointer.
// Each line y = k*x + b carries an integer tag (getans passes the DP group
// count). At a breakpoint the smaller tag is reported.
//
// Unchecked assumptions:
//   - lines are added with monotone slopes; getans presumably adds them with
//     decreasing slopes, giving a lower envelope for minimum queries — confirm;
//   - get() is called with non-decreasing x, since `cur` only moves forward there.
struct convex{
    // Hull lines in insertion order, each paired with its tag.
    vector<pair<Line, int>> v;
    // Index of the line that was optimal at the last query.
    int cur = 0;

    // Appends line l with tag a. First pops back lines made redundant by l,
    // comparing (breakpoint, tag) pairs so that ties are resolved by tag.
    // Keeps `cur` inside the hull when its line is popped.
    void add(Line l, int a){
        while(sz(v) > 1 && make_pair(intersect(v[sz(v) - 2].F, v.back().F), v.back().S) >= make_pair(intersect(v.back().F, l), a)){
            if(sz(v) - 1 == cur)
                cur--;
            v.pop_back();
        }
        v.push_back({l, a});
    }

    // Returns {value at x, tag} of the optimal line, or {infi, infi} if the
    // hull is empty. The long double value is truncated to an integer by the
    // conversion to pii. When x is exactly the breakpoint between line cur-1
    // and line cur, the smaller of the two tags is returned.
    pii get(int x){
        if(sz(v) == 0){
            return {infi, infi};
        }
        while(cur < sz(v) - 1 && intersect(v[cur].F, v[cur + 1].F) <= x)
            cur++;
        const ld value = v[cur].F.k * x + v[cur].F.b;
        // Bug fix: line cur-1 exists only when cur > 0. Previously v[-1] was
        // read (out-of-bounds, UB) whenever the first line was optimal.
        if(cur > 0 && intersect(v[cur].F, v[cur - 1].F) == x){
            return make_pair(value, min(v[cur].S, v[cur - 1].S));
        }
        return make_pair(value, v[cur].S);
    }
};
// Runs the suffix DP for a fixed penalty C.
// Every transition adds C to the cost and 1 to the count, which suggests the
// Lagrangian ("Aliens") relaxation driven by solve()'s binary search — confirm.
// dp[i] = {cost, transitions} for the suffix that starts at the i-th 'A';
// dp[n + 1] = {0, 0}. Returns dp[1].
pii getans(int C){
vector<pair<int, int>> dp(n + 2, {infi, infi});
dp[n + 1] = {0, 0};
convex cht;
// Two event lists, each in decreasing position order:
//   type 1, key (R[i], i):      compute dp[i] from the lines already in the hull;
//   type 2, key (ps[i], i + 1): finish dp[i + 1] and insert its line.
// Type 2 therefore carries index j = i + 1 keyed by ps[j - 1].
vector<pair<pii, int>> evs1(n);
vector<pair<pii, int>> evs2(n);
for(int i = n; i >= 1; i--){
evs1[n - i] = {{R[i], i}, 1};
evs2[n - i] = {{ps[i], i + 1}, 2};
}
// Merge the two lists into one list of 2n events, largest key first.
vector<pair<pii, int>> evs(2 * n);
int l = 0, r = 0;
while(l < sz(evs1) || r < sz(evs2)){
if(l == sz(evs1)){
evs[l + r] = evs2[r];
r++;
}else if(r == sz(evs2)){
evs[l + r] = evs1[l];
l++;
}else if(evs1[l] > evs2[r]){
evs[l + r] = evs1[l];
l++;
}else{
evs[l + r] = evs2[r];
r++;
}
}
// cj only moves left during the sweep (see the type-2 branch).
int cj = n + 1;
for(auto e : evs){
int i = e.F.S;
if(e.S == 1){
// Let c = prc[R[i] - 1], the number of 'B' before R[i].
// prp - c(c+1)/2 sums (position - rank) over those 'B's, i.e. for each of
// them the number of 'A' to its left.
int nzm = prp[R[i] - 1] - (prc[R[i] - 1] * prc[R[i] - 1] + prc[R[i] - 1]) / 2;
// Best line among those inserted so far (keys larger than R[i]), queried
// at x = -c. One transition is added to the count.
dp[i] = cht.get(-prc[R[i] - 1]);
dp[i].S++;
dp[i].F += nzm;
}else{
if(i <= n){
// Move cj left while the 'A' before it lies right of R[i]. Afterwards
// ps[cj - 1] < R[i] < ps[cj] (cj == n + 1: no 'A' after R[i]).
while(cj > 1 && R[i] < ps[cj - 1])
cj--;
// Alternative transition straight to dp[cj]: one more transition and
// no extra cost other than the C added below.
if(cj > i)
dp[i] = min(dp[i], {dp[cj].F, dp[cj].S + 1});
dp[i].F += C;
}
// Line contributed by the finished dp[i], with p = ps[i - 1], c = prc[p]:
//   slope     = p - c  (number of 'A' in s[1..p])
//   intercept = dp[i].F + c*p - prp[p] - c(c-1)/2
// The local `k` shadows the global k.
ld b = dp[i].F + prc[ps[i - 1]] * ps[i - 1] - prp[ps[i - 1]] - (prc[ps[i - 1]] * prc[ps[i - 1]] - prc[ps[i - 1]]) / 2;
ld k = ps[i - 1] - prc[ps[i - 1]];
cht.add({k, b}, dp[i].S);
}
}
// Type-2 events carry indices 2..n+1 only, so dp[1] is finished here with
// the same cj transition and penalty. No line is needed for it.
while(cj > 1 && R[1] < ps[cj - 1])
cj--;
if(cj > 1)
dp[1] = min(dp[1], {dp[cj].F, dp[cj].S + 1});
dp[1].F += C;
return dp[1];
}
// Reads n, k and the string, then normalizes it with normalize(), which counts
// swaps into `ans`. Next it fills prc/prp/ps/R. Finally it binary-searches
// the smallest penalty C with getans(C).S <= k and prints
// getans(C).F + ans - k*C.
void solve(){
cin >> n >> k;
cin >> s;
s = "#" + s;
normalize();
// FIFO of 'B' positions, pushed while scanning right to left. Its front is
// the rightmost 'B' not yet paired.
queue<int> st;
// Prefix count of 'B' and prefix sum of 'B' positions.
for(int i = 1; i <= 2 * n; i++){
prc[i] = prc[i - 1] + (s[i] == 'B');
prp[i] = prp[i - 1] + (s[i] == 'B') * i;
}
// Scanning right to left, the (cnt+1)-th 'A' from the right is paired with
// the (cnt+1)-th 'B' from the right. Both are stored under index n - cnt.
int cnt = 0;
for(int i = 2 * n; i >= 1; i--){
if(s[i] == 'B'){
st.push(i);
continue;
}
R[n - cnt] = st.front();
ps[n - cnt] = i;
st.pop();
cnt++;
}
// Invariant: getans(lo).S > k, and hi is a candidate with getans(hi).S <= k.
// The initial hi = 10n^2 + 1 is presumably large enough — TODO confirm.
int lo = -1, hi = 10 * n * n + 1;
while(lo + 1 < hi){
int C = (lo + hi) / 2;
auto e = getans(C);
if(e.S > k)
lo = C;
else
hi = C;
}
int C = hi;
auto e = getans(C);
// Subtract the k penalties C and add the swaps spent in normalize().
cout << e.F + ans - (k * C) << '\n';
}
// Entry point: enables fast I/O and solves the single test case in the input.
// `signed` is required because `int` is redefined to `long long` above.
signed main() {
    mispertion;
    solve();
    return 0;
}
/* Grader results table captured from the oj.uz submission page (the results were still loading):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/