#include<bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
// Judge-oriented compiler tuning: -Ofast-level optimisation, loop unrolling, and
// AVX/AVX2/FMA code generation for the functions that follow.
// NOTE(review): the target pragma assumes the judge CPU supports these
// instruction sets — confirm, otherwise the binary may fault at runtime.
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("avx,avx2,fma")
using namespace __gnu_pbds;
// Order-statistics tree over (int, int) pairs (find_by_order / order_of_key).
// Not referenced by the code below.
#define ordered_set tree<pair<int, int>, null_type,less<pair<int, int>>, rb_tree_tag,tree_order_statistics_node_update>
// Contest shorthands. Of these, only `ll` is used below: MOD appears only in the
// commented-out modular branch of power(), and MAXN (a double literal, 2e5),
// SIZE and pb are not referenced.
#define ll long long
#define MOD 1000000007
#define MAXN 2e5
#define SIZE 314
#define pb push_back
using namespace std;
// Computes a^b by binary exponentiation using plain (non-modular) long long
// arithmetic. Any exponent b <= 0 yields 1. The result is exact only while
// a^b fits in a long long; larger values overflow the signed type.
ll power(ll a, ll b){
    ll result = 1;
    ll base = a;
    for (ll e = b; e > 0; e /= 2){
        if (e % 2 == 1) result *= base;
        // Square only when more bits remain, so base never exceeds a^b.
        if (e > 1) base *= base;
    }
    return result;
}
// Records one more occurrence of leaf `pos` in the counting segment tree.
// Starting at node `curr` (covering [l, r]), every node on the path down to the
// leaf for `pos` is incremented. The children of node x are 2x+1 (left half,
// [l, mid]) and 2x+2 (right half, [mid+1, r]).
void update(vector<int> &segTree, int curr, int pos, int l, int r){
    int node = curr;
    int lo = l, hi = r;
    for (;;){
        ++segTree[node];
        if (lo == hi) return;
        const int mid = lo + (hi - lo) / 2;
        if (pos <= mid){
            node = 2 * node + 1;
            hi = mid;
        }
        else{
            node = 2 * node + 2;
            lo = mid + 1;
        }
    }
}
// Returns the total count stored at leaves findL..findR (inclusive) within the
// subtree of node `curr`, which covers [l, r]. The children of node x are 2x+1
// and 2x+2. There is no check for a query disjoint from [l, r]; callers are
// expected to pass an overlapping range.
int getCount(vector<int> &segTree, int curr, int l, int r, int findL, int findR){
    // Node interval lies entirely inside the query: its stored count is the answer.
    if (findL <= l && findR >= r) return segTree[curr];
    const int split = l + (r - l) / 2;
    const int fromLeft = (findL <= split)
        ? getCount(segTree, 2 * curr + 1, l, split, findL, findR)
        : 0;
    const int fromRight = (split < findR)
        ? getCount(segTree, 2 * curr + 2, split + 1, r, findL, findR)
        : 0;
    return fromLeft + fromRight;
}
// Counts crossing chord pairs when the i-th white point is joined to the
// ((i + k) mod n)-th black point.
//
// Each chord is normalised to an interval [lo, hi] on positions 0..2n-1, and the
// intervals are sorted by lo. All 2n endpoints are distinct, so two chords cross
// exactly when an earlier interval (smaller lo) ends strictly inside a later one.
// For each interval we therefore count earlier right endpoints in [lo, hi] with
// the segment tree, then insert its own right endpoint.
//
// The result can reach n*(n-1)/2 (e.g. "WW..WBB..B" with k = 0, where every pair
// of chords crosses). That overflows int once n > 65536, so it is accumulated
// and returned as long long.
//
// n       number of white (and black) points.
// w, b    positions of the white / black points, in increasing order.
// k       cyclic shift of the pairing; must be in [0, n) so the index is valid.
// allPair scratch buffer of size n, overwritten on every call.
long long getAns(int n, vector<int> &w, vector<int> &b, int k, vector<pair<int, int>> &allPair){
    for (int i = 0; i < n; i++){
        const int wPos = w[i], bPos = b[(i + k) % n];
        allPair[i] = {min(wPos, bPos), max(wPos, bPos)};
    }
    sort(allPair.begin(), allPair.end());
    vector<int> segTree(8 * n, 0);  // 4 * (2n) nodes for 2n leaves
    long long crossings = 0;
    for (int i = 0; i < n; i++){
        crossings += getCount(segTree, 0, 0, 2 * n - 1, allPair[i].first, allPair[i].second);
        update(segTree, 0, allPair[i].second, 0, 2 * n - 1);
    }
    return crossings;
}
// Reads n and a string of 2n 'W'/'B' characters. Prints the largest crossing
// count getAns(k) found over the cyclic pairings k = 0..n-1.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int n;
    cin >> n;
    string s;
    cin >> s;
    // Positions of the white and black points, collected in increasing order.
    vector<int> w(n), b(n);
    vector<pair<int, int>> allPair(n);
    int curW = 0, curB = 0;
    for (int i = 0; i < 2 * n; i++){
        if (s[i] == 'W') w[curW++] = i;
        else b[curB++] = i;
    }
    // Binary search for the peak of getAns(k): move right while getAns(mid)
    // exceeds getAns(mid - 1). NOTE: this relies on getAns being unimodal in k;
    // that is assumed here, not checked. Every evaluated value is folded into the
    // answer, and the endpoints k = 0 and k = n - 1 are checked afterwards as a
    // safeguard.
    long long ans = 0;
    int left = 0, right = n - 1;
    while (left <= right){
        const int mid = (right - left) / 2 + left;
        const long long curr = getAns(n, w, b, mid, allPair);
        const long long prev = (mid >= 1) ? getAns(n, w, b, mid - 1, allPair) : 0;
        ans = max(ans, curr);
        if (curr > prev) left = mid + 1;
        else right = mid - 1;
    }
    ans = max(ans, getAns(n, w, b, 0, allPair));
    ans = max(ans, getAns(n, w, b, n - 1, allPair));
    cout << ans << "\n";
    return 0;
}
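/*
 * Judge-status table scraped alongside this submission (results had not loaded).
 * Kept verbatim inside a comment because it is not C++ and would break compilation.
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */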