//
// Created by adavy on 7/22/2025.
//
#include <bits/stdc++.h>
using namespace std;
// NOTE: from here on every `int` is textually replaced by `long long`; the
// entry point's signature below therefore spells the 32-bit type as `signed`.
#define int long long
using ll = long long;
// Large sentinel cost for states that cannot be reached.
const ll INF = 1e18;
// Approach: separate the sorted points into consecutive single-colour blocks
// (R, B, R, B, ...) and run a DP over prefixes of that sorted sequence,
// where each DP state means "the first i points already have a wire".
#include "wiring.h"
// Closed-form wiring cost of one group made of:
//   * `num_back` points at the end of one block (positions sum to `sum_back`,
//     rightmost one at `max_back`), and
//   * `num_front` points at the start of the following block (positions sum to
//     `sum_front`, leftmost one at `min_front`).
// Assumes every back position lies left of every front position.
// The result is sum_front - sum_back, corrected for the side with surplus
// points: extra front points are charged against max_back, extra back points
// against min_front. With equal counts no correction applies.
long long calc_score(ll num_back, ll num_front, ll sum_back, ll sum_front, ll max_back, ll min_front){
    ll cost = sum_front - sum_back;
    if (num_front > num_back) {
        const ll surplus_front = num_front - num_back;
        cost -= surplus_front * max_back;
    } else if (num_back > num_front) {
        const ll surplus_back = num_back - num_front;
        cost += surplus_back * min_front;
    }
    return cost;
}
// Minimum total wire length so that every red point (r) and every blue point
// (b) is attached to at least one point of the other colour.
//
// Parameters: r, b -- positions of the red / blue points, assumed sorted
//   ascending with all positions pairwise distinct (the merge relies on it).
// Returns: the minimum total length; 0 if both lists are empty, and the
//   sentinel 1e18 if exactly one list is empty (no valid wiring exists).
//
// Algorithm, O(n + m): merge the points into one sorted sequence and split it
// into maximal single-colour blocks. dp[i] is the cheapest cost that gives a
// wire to each of the first i points. Consider the point at 0-based index idx.
// It is the k-th point (k = idx - s + 1) of a block starting at index s. Then:
//   * either it is wired to its nearest point of the other colour (the last
//     point of the previous block or the first point of the next block):
//         dp[idx + 1] = dp[idx] + nearest;
//   * or block points s..idx are paired one-to-one with the last k points of
//     the previous block (that block must hold at least k points), costing
//     sum(s..idx) - sum(s-k..s-1).
long long min_total_length(std::vector<signed> r, std::vector<signed> b) {
    constexpr long long kNoWiring = 1000000000000000000LL;
    if (r.empty() && b.empty()) return 0;
    if (r.empty() || b.empty()) return kNoWiring;  // someone can never get a wire

    // Merge both sorted lists into (position, isBlue) pairs. The explicit
    // bounds checks avoid reading past the end of either list.
    std::vector<std::pair<long long, bool>> pts;
    pts.reserve(r.size() + b.size());
    std::size_t ri = 0, bi = 0;
    while (ri < r.size() || bi < b.size()) {
        if (bi == b.size() || (ri < r.size() && r[ri] < b[bi])) {
            pts.emplace_back(r[ri++], false);
        } else {
            pts.emplace_back(b[bi++], true);
        }
    }
    const std::size_t n = pts.size();

    // blockStart[i] / blockEnd[i]: first / last index of the block holding i.
    std::vector<std::size_t> blockStart(n), blockEnd(n);
    for (std::size_t i = 0; i < n; ++i) {
        const bool continues = i > 0 && pts[i].second == pts[i - 1].second;
        blockStart[i] = continues ? blockStart[i - 1] : i;
    }
    for (std::size_t i = n; i-- > 0;) {
        const bool continues = i + 1 < n && pts[i].second == pts[i + 1].second;
        blockEnd[i] = continues ? blockEnd[i + 1] : i;
    }

    // prefix[i] = sum of the first i positions (for O(1) range sums).
    std::vector<long long> prefix(n + 1, 0);
    for (std::size_t i = 0; i < n; ++i) prefix[i + 1] = prefix[i] + pts[i].first;

    std::vector<long long> dp(n + 1, 0);
    for (std::size_t idx = 0; idx < n; ++idx) {
        const long long pos = pts[idx].first;
        const std::size_t s = blockStart[idx];
        const std::size_t e = blockEnd[idx];

        // Option 1: wire to the nearest opposite-coloured point. Both colours
        // are present, so at least one neighbouring block exists.
        long long nearest = kNoWiring;
        if (s > 0) nearest = pos - pts[s - 1].first;
        if (e + 1 < n) nearest = std::min(nearest, pts[e + 1].first - pos);
        long long best = dp[idx] + nearest;

        // Option 2: match block points s..idx with the last k points of the
        // previous block.
        if (s > 0) {
            const std::size_t k = idx - s + 1;
            const std::size_t prevStart = blockStart[s - 1];
            if (s - prevStart >= k) {
                const long long front = prefix[idx + 1] - prefix[s];
                const long long back = prefix[s] - prefix[s - k];
                // The first back point (index s-k) may already be wired by an
                // earlier group and is shared with this one. For example, the
                // single blue point in r={1,2,4,5}, b={3} must serve both
                // red blocks. Hence dp[s-k+1] is allowed as well as dp[s-k].
                const long long before = std::min(dp[s - k], dp[s - k + 1]);
                best = std::min(best, before + front - back);
            }
        }
        dp[idx + 1] = best;
    }
    return dp[n];
}
/*
#include <cassert>
#include <cstdio>
signed main() {
int n, m;
assert(2 == scanf("%d %d", &n, &m));
vector<int> r(n), b(m);
for(int i = 0; i < n; i++)
assert(1 == scanf("%d", &r[i]));
for(int i = 0; i < m; i++)
assert(1 == scanf("%d", &b[i]));
long long res = min_total_length(r, b);
printf("%lld\n", res);
return 0;
}*/
/*
Judge submission results, pasted from the grader web page while the results
were still loading (kept as a comment so the file compiles):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/