제출 #1149116

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1149116 | - | _8_8_ | Chorus (JOI23_chorus) | C++17 | 100 / 100 | 2242 ms | 100652 KiB
#include <bits/stdc++.h>

using namespace std;
typedef long long ll;

// N bounds the size of the global arrays below. MOD and inf are not
// referenced anywhere in the code visible in this file.
const int  N = (int)1e6 + 10, MOD = 998244353, inf = (int)1e9;

// From this point on every `int` token is replaced by `ll` (long long);
// presumably this is why main() is declared with int32_t.
#define int ll

// n, k : the two integers read at the start of test().
// a[j] : 1-based position in s of the j-th 'A' (filled in test()).
// b[j] : 1-based position in s of the j-th non-'A' character (filled in
//        test(); never read in the visible code).
// f[i] : after the prefix pass in test(), the number of non-'A' characters
//        among the first i characters of s.
// px, py : not referenced in the visible code.
int n, k, a[N], b[N], px[N * 2], py[N * 2], f[N * 2];
// Input string; make() overwrites it with its rearranged form.
string s;

// Accumulated by make() and added to the value printed by test().
ll dop = 0;
// Rewrites the global string s (its first n + n characters) in place.
//
// Characters are scanned left to right. An 'A' is always emitted. Any other
// character is emitted only if an earlier emitted 'A' is still unmatched;
// otherwise it is withheld, and a 'B' is emitted for it as soon as an 'A'
// becomes available. Every time an 'A' is emitted, the number of characters
// withheld at that moment is added to the global counter dop.
//
// Note: `int` is a macro for long long in this file.
void make() {
    string rebuilt;
    int open_a = 0;   // emitted 'A's that have not been matched yet
    int parked = 0;   // non-'A' characters withheld so far
    const int total = n + n;

    for(int pos = 0; pos < total; ++pos) {
        const char ch = s[pos];
        if(ch == 'A') {
            rebuilt += ch;
            dop += parked;
            ++open_a;
        } else if(open_a) {
            --open_a;
            rebuilt += ch;
        } else {
            ++parked;
        }

        // Release withheld characters while there is an 'A' to pair them with.
        for(; parked && open_a; --parked, --open_a) {
            rebuilt += 'B';
        }
    }

    s = rebuilt;
}
// dp[r]  : value computed by solve() for prefix r (penalty included).
// col[r] : number of steps solve() recorded on the way to dp[r]
//          (col of the chosen predecessor plus one).
// pr[i]  : prefix sum of f[a[j]] over j = 1..i, built in test().
ll dp[N], col[N], pr[N];
// nxt[i] : built in test(); starting from max(i + 1, nxt[i - 1]), the first
//          index whose f[a[.]] is at least i, or n + 1 if there is none.
int nxt[N];
// A linear function y = k * x + b together with the index of the DP state
// it was built from (id is used by calc() to look up col[id]).
struct line {
    ll k;    // slope
    ll b;    // intercept
    int id;  // originating DP index (`int` is long long in this file)

    // Evaluates the line at x.
    ll get(ll x) {
        return b + k * x;
    }
};

// Candidate lines for the current solve() run: add() appends/pops at the
// back, calc() pops from the front; solve() clears it before each run.
deque<line> st;

// Returns the abscissa at which lines x and y take the same value:
//   x.k * t + x.b == y.k * t + y.b   =>   t = (x.b - y.b) / (y.k - x.k).
//
// The operands are widened to long double *before* dividing. Multiplying an
// integer by the literal 1.0 would only promote it to double, so the quotient
// would be rounded to double precision even though the function returns
// long double.
//
// Precondition: x.k != y.k (otherwise the division is by zero).
long double cross(line x, line y) {
    const long double numerator = static_cast<long double>(x.b - y.b);
    const long double denominator = static_cast<long double>(y.k - x.k);
    return numerator / denominator;
}
// Appends nv to the back of the deque st.
//
// Before appending, the current last line is removed for as long as the
// deque holds at least two lines and the intersection of the last two lines
// lies at or to the right of the intersection of the second-to-last line
// with nv.
void add(line nv) {
    while(st.size() >= 2 &&
          cross(st.back(), st[st.size() - 2]) >= cross(st[st.size() - 2], nv)) {
        st.pop_back();
    }
    st.push_back(nv);
}
// Queries the deque st at abscissa x.
//
// Front lines are dropped while the intersection of the first two lines lies
// strictly left of x. Returns {value of the front line at x,
// col[front.id] + 1}.
//
// Precondition: st is not empty.
pair<ll, int> calc(ll x) {
    const double query = static_cast<double>(x);
    while(st.size() > 1 && cross(st[0], st[1]) < query) {
        st.pop_front();
    }
    line& front = st.front();
    return {front.get(x), col[front.id] + 1};
}
// w[r] is the line solve() derives from dp[r]; solve() hands it to add()
// once nxt[r] no longer exceeds the index currently being processed.
line w[N];
// Runs the DP over r = 1..n where every step costs an extra `pen`.
//
// For each r, the lines w[j] with nxt[j] <= r are fed to add(), then calc(r)
// supplies the best line value and the step count of the state it came from.
// dp[r] is that value plus pr[r] plus pen; col[r] is the step count.
// Afterwards the line for state r is stored in w[r].
//
// Returns {dp[n], col[n]}: the penalised total and the number of steps used.
// Overwrites the globals st, dp, col and w.
pair<ll, int> solve(ll pen) {
    st.clear();

    // State 0 is the only starting point; everything else starts "unreached".
    fill(dp + 1, dp + n + 1, (ll)1e18);
    fill(col, col + n + 1, 0);
    dp[0] = 0;
    w[0] = line{0, 0, 0};

    int pending = 0;  // first line index not yet handed to add()
    for(int right = 1; right <= n; ++right) {
        for(; nxt[pending] <= right; ++pending) {
            add(w[pending]);
        }

        const pair<ll, int> best = calc(right);
        dp[right] = best.first + pr[right] + pen;
        col[right] = best.second;

        const int last_skipped = nxt[right] - 1;
        w[right] = line{-right,
                        dp[right] - pr[last_skipped] + right * last_skipped,
                        right};
    }

    return {dp[n], col[n]};
}
void test() {
    cin >> n >> k;
    cin >> s;
    make();
    int fx = 1, fy = 1;
    for(int i = 1; i <= n + n; i++) {
        if(s[i - 1] == 'A') {
            a[fx++] = i;
        } else {
            b[fy++] = i;
            f[i]++;
        }
    }
    for(int i = 1; i <= n + n; i++) {
        f[i] += f[i - 1];
    }
    for(int i = 1; i <= n; i++) {
        pr[i] = pr[i - 1] + f[a[i]];
    }
    for(int i = 0; i <= n; i++) {
        if(i) {
            nxt[i] = max(i + 1, nxt[i - 1]);
        } else {
            nxt[i] = i + 1;
        }
        while(nxt[i] <= n && f[a[nxt[i]]] < i) {
            nxt[i]++;
        }
    }
    // auto [x, y] = solve(0);
    // cout << x << ' ' << y;
    // return;
    ll l = -1, r = (ll)1e14, res;
    while(r - l > 1) {
        ll mid = (l + r) >> 1;
        auto [val, t] = solve(mid);
        if(t <= k) {
            r = mid;
            res = val - k * 1ll * mid;
        } else {
            l = mid;
        }
    }
    cout << res + dop << '\n';
}
// Entry point: switches iostream to fast mode and runs test() once.
// Declared with int32_t because `int` is a macro for long long in this file.
int32_t main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    // The input holds a single instance; read the count here instead if the
    // format ever provides several.
    int cases = 1;
    // cin >> cases;

    for(; cases > 0; --cases) {
        test();
    }
}

# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...