Submission #1150881

#TimeUsernameProblemLanguageResultExecution timeMemory
1150881AmirAli_H1Chorus (JOI23_chorus)C++20
40 / 100
7097 ms99308 KiB
// In the name of Allah

#include <bits/stdc++.h>
using namespace std;

typedef		long long int			ll;
typedef		long double				ld;
typedef		pair<int, int>			pii;
typedef		pair<ll, ll>			pll;
typedef		complex<ld>				cld;

#define		all(x)					(x).begin(),(x).end()
#define		len(x)					((ll) (x).size())
#define		F						first
#define		S						second
#define		pb						push_back
#define		sep						' '
#define		endl					'\n'
#define		Mp						make_pair
#define		kill(x)					cout << x << '\n', exit(0)
#define		set_dec(x)				cout << fixed << setprecision(x);
#define		file_io(x,y)			freopen(x, "r", stdin); freopen(y, "w", stdout);
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

const int maxn = (1 << 20) + 4;
const ll oo = 1e18 + 4;
const ll inf = 1e12 + 4;

struct node {
	ll lazy; pll sum;
	ll valx; pll max;
};

int n, k; string s;
vector<ll> ls; pll dp[maxn];
node t[2 * maxn], null; set<int> st;

node f(node a, node b) {
	node res; res.lazy = 0;
	res.valx = a.valx + b.valx;
	res.sum.F = a.sum.F + b.sum.F;
	res.sum.S = a.sum.S + b.sum.S;
	res.max = max(a.max, b.max);
	return res;
}

void shift(int v, int tl, int tr) {
	ll x = t[v].lazy; t[v].lazy = 0;
	if (x == 0 || tr - tl == 1) return ;
	for (int u : {2 * v + 1, 2 * v + 2}) {
		t[u].lazy += x; t[u].max.F += x;
		t[u].sum.F += t[u].valx * x;
	}
}

void build(int v, int tl, int tr) {
	t[v].lazy = 0;
	if (tr - tl == 1) {
		t[v] = null;
		return ;
	}
	int mid = (tl + tr) / 2;
	build(2 * v + 1, tl, mid); build(2 * v + 2, mid, tr);
	t[v] = f(t[2 * v + 1], t[2 * v + 2]);
}

void set_val(int v, int tl, int tr, int i, pll x, int valx) {
	shift(v, tl, tr);
	if (tr - tl == 1) {
		if (valx == 0) {
			t[v] = null;
		}
		else {
			t[v].sum = x; t[v].valx = valx;
			if (x.F > 0) {
				t[v].max = Mp(0, tl);
			}
			else {
				ll R1 = -x.F;
				if (R1 % valx != 0) {
					t[v].max = Mp((R1 / valx) + 1, tl);
				}
				else {
					t[v].max = Mp((R1 / valx) + (x.S < 0), tl);
				}
			}
			t[v].max.F = (-t[v].max.F);
		}
		return ;
	}
	int mid = (tl + tr) / 2;
	if (i < mid) set_val(2 * v + 1, tl, mid, i, x, valx);
	else set_val(2 * v + 2, mid, tr, i, x, valx);
	t[v] = f(t[2 * v + 1], t[2 * v + 2]);
}

void add_val(int v, int tl, int tr, int l, int r, ll x) {
	l = max(l, tl); r = min(r, tr);
	if (l >= tr || r <= tl) return ;
	shift(v, tl, tr);
	if (l == tl && r == tr) {
		t[v].lazy += x; t[v].max.F += x;
		t[v].sum.F += t[v].valx * x;
		return ;
	}
	int mid = (tl + tr) / 2;
	add_val(2 * v + 1, tl, mid, l, r, x); add_val(2 * v + 2, mid, tr, l, r, x);
	t[v] = f(t[2 * v + 1], t[2 * v + 2]);
}

node get_res(int v, int tl, int tr, int l, int r) {
	l = max(l, tl); r = min(r, tr);
	if (l >= tr || r <= tl) return null;
	shift(v, tl, tr);
	if (l == tl && r == tr) return t[v];
	int mid = (tl + tr) / 2;
	return f(get_res(2 * v + 1, tl, mid, l, r), get_res(2 * v + 2, mid, tr, l, r));
}

void updx(int i, ll valx) {
	auto f = dp[i]; f.F += valx; f.S += 1;
	if (len(st) > 0) {
		int j = (*st.rbegin());
		node res = get_res(0, 0, n + 1, j, j + 1);
		auto g = res.sum; g.F -= (f.F * res.valx); g.S -= (f.S * res.valx);
		set_val(0, 0, n + 1, j, g, res.valx);
	}
	set_val(0, 0, n + 1, i, f, 1); st.insert(i);
	// checkx();
}

void cal(ll valx) {
	build(0, 0, n + 1); st.clear();
	dp[0] = Mp(0, 0); updx(0, valx);
	for (int i = 1; i <= n; i++) {
		// addx(ls[i - 1]);
		add_val(0, 0, n + 1, 0, ls[i - 1], 1);
		dp[i] = Mp(oo, oo);
		for (int j = 0; j < i; j++) dp[i] = min(dp[i], get_res(0, 0, n + 1, j, n + 1).sum);
		updx(i, valx);
	}
}

void solve() {
	cin >> n >> k >> s;
	ll res = 0; int x = 0;
	for (int i = 0; i < (2 * n); i++) {
		if (s[i] == 'B') x++;
		else ls.pb(x);
	}
	for (int i = 0; i < n; i++) {
		if (ls[i] > i) {
			res += (ls[i] - i); ls[i] = i;
		}
	}

	ll l = -1, r = inf;
	while (r - l > 1) {
		ll mid = (l + r) / 2; cal(mid);
		if (dp[n].S > k) l = mid;
		else r = mid;
	}
	cal(r);
	cout << res + (dp[n].F - (k * r)) << endl;
}

int main() {
	ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);

	null.lazy = 0; null.valx = 0;
	null.sum = Mp(0, 0); null.max = Mp(-oo, -1);

	int T = 1;
	while (T--) {
		solve();
	}
	
	return 0;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...