제출 #896068

| # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
|---|---|---|---|---|---|---|---|
| 896068 |  | cadmiumsky | Chorus (JOI23_chorus) | C++17 | 100 / 100 | 1179 ms | 75892 KiB |
// JOI 2023 "Chorus" (JOI23_chorus).
//
// Outline of what the program below does:
//  * Read the 2n-character string; for the j-th 'B' record in B[j] how many
//    'A's precede it.
//  * Raise every B[i] to at least max(B[i-1], i), accumulating the total raise
//    in `fixing`. Afterwards B is non-decreasing and B[i] >= i.
//  * Split indices 1..n into groups with a DP in which every group pays an
//    extra penalty `lambda` (Aliens / Lagrangian-relaxation trick). Each DP
//    run is evaluated with a monotone convex-hull-trick stack.
//  * Search lambda bit by bit and print dp[1].val - lambda * k + fixing.
#include <bits/stdc++.h>
using namespace std;

using ll = long long;
// Every `int` below is 64-bit; this is also why the entry point is
// `signed main` (an `int main` would expand to `long long main`).
#define int ll
#define sz(x) ((int)(x).size())
using pii = pair<int, int>;

const int nmax = 1e6 + 5;

// B[j]: number of 'A's before the j-th 'B'. After the normalisation pass in
// main it is non-decreasing with B[j] >= j. B[0] stays 0.
int B[nmax];
// After main's prefix pass:
//   cnt_under[v] = number of t with B[t] <= v
//   sum_under[v] = sum of B[t] over those t
ll sum_under[nmax], cnt_under[nmax];
// sum_pref[i] = B[1] + ... + B[i].
ll sum_pref[nmax];

// A DP value paired with the number of groups used to reach it.
struct IWCnt {
  ll val;
  ll cnt;
  IWCnt() : val(0), cnt(0) {}
  IWCnt(ll a, ll b) : val(a), cnt(b) {}
  // Adds both the values and the group counts.
  IWCnt operator+(const IWCnt& x) const { return IWCnt(val + x.val, cnt + x.cnt); }
  // Shifts the value only; the group count is unchanged.
  IWCnt operator+(const ll& x) const { return IWCnt(val + x, cnt); }
  // Scales the value only (not called anywhere in this program).
  IWCnt operator*(const ll& x) const { return IWCnt(val * x, cnt); }
};

// Convex-hull trick over lines y = m * x + b.val, queried for the minimum with
// a pointer that only moves forward between clear() calls.
namespace CHT {
struct Line {
  ll m;
  IWCnt b;
  Line(ll a, IWCnt c) : m(a), b(c) {}
  ll operator()(const ll& x) const { return m * x + b.val; }
};

// True when `last` (the current top of the stack) can be discarded once `nv`
// is placed after `second`. Intersection points are compared by
// cross-multiplying.
// NOTE(review): assumes (value difference) * (slope difference) fits in 64 bits.
bool bad(Line second, Line last, Line nv) {
  return (nv.b.val - last.b.val) * (second.m - nv.m) <
         (nv.b.val - second.b.val) * (last.m - nv.m);
}

vector<Line> st;
int ptr;

// Appends a line, first popping tops that `bad` reports as unnecessary.
void push(Line a) {
  while (sz(st) > 1 && bad(rbegin(st)[1], rbegin(st)[0], a)) st.pop_back();
  st.emplace_back(a);
}

// Returns {minimum over the stacked lines at x = P, cnt of the chosen line}.
pii query(int P) {
  // Pops in push() may have left the pointer past the current top.
  ptr = min(ptr, sz(st) - 1);
  // Advance while the next line is strictly better. On equal values, advance
  // toward the line with the larger group count.
  // Author's note (translated from Romanian): the large ones are theoretically
  // the first encountered, and I need to find them to make sure I'm on the
  // correct slope.
  while (ptr + 1 < sz(st) &&
         (st[ptr](P) > st[ptr + 1](P) ||
          (st[ptr](P) == st[ptr + 1](P) && st[ptr + 1].b.cnt > st[ptr].b.cnt)))
    ptr++;
  return pii{st[ptr](P), st[ptr].b.cnt};
}

void clear() {
  st.clear();
  ptr = 0;
}
}  // namespace CHT

IWCnt dp[nmax];
int n, k;

// Penalised DP for a fixed lambda, computed from right to left.
//
//   dp[i] = min over ends j (j >= B[i]) of
//           cost(i, j) + lambda + dp[j + 1].
//
// The line for end j has slope j and intercept
//   dp[j + 1] + j * cnt_under[j] - sum_under[j] + j.
// Evaluating it at x = -i and adding sum_pref[i - 1] gives
//   cost(i, j) = sum over t in [i, j] with B[t] <= j of (j - B[t]).
// That identity uses B being non-decreasing with B[t] >= t.
//
// Slopes are pushed in decreasing order and query points x = -i increase,
// which is what CHT's forward-only pointer relies on. dp[i].cnt counts groups.
// Returns true iff the optimum found uses at least k groups.
bool check(ll lambda) {
  int last_nv = n;
  CHT::clear();
  dp[n + 1].val = 0;
  dp[n + 1].cnt = 0;
  for (int i = n; i > 0; i--) {
    // Make every group end j >= B[i] available (B is non-decreasing, so ends
    // added for larger i stay valid).
    while (last_nv >= B[i]) {
      CHT::push(CHT::Line(
          last_nv,
          dp[last_nv + 1] +
              (-sum_under[last_nv] + cnt_under[last_nv] * last_nv + last_nv)));
      last_nv--;
    }
    auto [C, tp] = CHT::query(-i);
    dp[i].val = C + sum_pref[i - 1] + lambda;
    dp[i].cnt = tp + 1;
  }
  return dp[1].cnt >= k;
}

signed main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);

  cin >> n >> k;

  // cnt[0]: 'A's read so far. cnt[1]: 'B's (any non-'A' character) read so far.
  int cnt[2] = {0, 0};
  char ch;
  for (int i = 0; i < 2 * n; i++) {
    cin >> ch;
    if (ch == 'A')
      cnt[0]++;
    else
      B[++cnt[1]] = cnt[0];
  }

  // Raise B[i] to max(B[i-1], i); the total amount raised is added back to
  // the final answer.
  ll fixing = 0;
  for (int i = 1; i <= n; i++) {
    const ll target = max(B[i - 1], i);
    const ll raise = max(0LL, target - B[i]);
    fixing += raise;
    B[i] += raise;
  }

  // Bucket counts/sums by B value, then prefix-sum them over values 0..n.
  for (int i = 1; i <= n; i++) {
    cnt_under[B[i]]++;
    sum_under[B[i]] += B[i];
    sum_pref[i] = sum_pref[i - 1] + B[i];
  }
  for (int i = 1; i <= n; i++) {
    cnt_under[i] += cnt_under[i - 1];
    sum_under[i] += sum_under[i - 1];
  }

  // Greedy bit-by-bit search (bits 40..0) for the largest lambda with
  // check(lambda) true. This assumes the group count is non-increasing in
  // lambda.
  ll lambda = 0;
  for (int bit = 40; bit >= 0; bit--)
    if (check(lambda + (1LL << bit))) lambda += (1LL << bit);

  check(lambda);  // recompute dp[] for the chosen lambda
  cout << dp[1].val - lambda * k + fixing << '\n';
}
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...