#include <bits/stdc++.h>
using namespace std;
// Forward declarations. attemptMergeNext is declared before LinkedList's body so
// that LinkedList members could call it; it is defined after LinkedList and
// currently has no active callers in this file.
template<typename T> struct LinkedList;
void attemptMergeNext(LinkedList<long long>& list, int ind);
// Doubly linked list whose nodes live in a vector and link to each other by index.
// Nodes are never physically removed, only unlinked, so indices stay stable for the
// lifetime of the list. Every node has a version counter that erase() and merge()
// bump, letting callers recognise stale references to a node (greedySolve uses
// this for lazy deletion from its priority queue).
template<typename T> struct LinkedList{
    int frontInd = -1;   // index of the first live node, -1 when the list is empty
    int backInd = -1;    // index of the last live node, -1 when the list is empty
    struct Node{
        int next = -1;     // index of the following node, -1 if none
        int prev = -1;     // index of the preceding node, -1 if none
        int version = 0;   // incremented whenever the node is erased or merged
        T val{};           // value-initialised so a default Node never holds garbage
        Node() = default;
        explicit Node(T v) : val(v) {}
    };
    vector<Node> nodes;

    // Unlinks node `ind` and bumps its version. The node keeps its own next/prev
    // links and its slot in `nodes`; only the neighbours and the list ends are
    // updated. (Re-merging the neighbours afterwards is intentionally not done.)
    void erase(int ind){
        Node& node = nodes[ind];
        node.version++;
        if(node.next != -1){
            nodes[node.next].prev = node.prev;
        }
        if(node.prev != -1){
            nodes[node.prev].next = node.next;
        }
        if(frontInd == ind){
            frontInd = node.next;
        }
        if(backInd == ind){
            backInd = node.prev;
        }
    }

    // Absorbs the neighbours of node `ind` (next first, then prev, whichever exist):
    // its value becomes prev + self + next and the absorbed neighbours are erased.
    // Bumps the version of `ind` as well.
    void merge(int ind){
        nodes[ind].version++;
        const int nxt = nodes[ind].next;
        if(nxt != -1){
            nodes[ind].val += nodes[nxt].val;
            erase(nxt);
        }
        const int prv = nodes[ind].prev;
        if(prv != -1){
            nodes[ind].val += nodes[prv].val;
            erase(prv);
        }
    }

    // Appends `val` at the back. If either the new value or the previous back value
    // is zero, or both are positive, or both are negative, the new node is merged
    // with the previous back node, so consecutive same-sign values form one node.
    void append(T val){
        const int ind = static_cast<int>(nodes.size());
        Node node(val);
        const int prevInd = backInd;
        if(prevInd == -1){
            frontInd = ind;
        }else{
            nodes[prevInd].next = ind;
            node.prev = prevInd;
        }
        backInd = ind;
        nodes.push_back(node);
        if(prevInd == -1) return;
        const T cur = nodes[ind].val;
        const T before = nodes[prevInd].val;
        const bool sameSide = (cur > 0) == (before > 0);
        if(cur == 0 || before == 0 || sameSide){
            merge(ind);
        }
    }

    // Value of the first / last live node. The list must be non-empty.
    T& front(){
        return nodes[frontInd].val;
    }
    T& back(){
        return nodes[backInd].val;
    }

    // Read-only view of node `ind` (live or erased) without copying it. The
    // reference is invalidated by the next append().
    const Node& get(int ind) const{
        return nodes[ind];
    }
};
// Hook for re-merging node `ind` of `list` with its successor. It is deliberately
// a no-op: an earlier draft body called `list.mergeNext(ind)`, which LinkedList
// does not provide, and nothing in this file currently calls this function.
void attemptMergeNext(LinkedList<long long>& list, int ind){
    static_cast<void>(list);
    static_cast<void>(ind);
}
// Greedy solver for the quantity dpSolve computes: the maximum total sum of at most
// k disjoint subarrays of inp[0..n-1], where choosing nothing scores 0.
//
// 1. Zeros are skipped and LinkedList::append merges same-sign neighbours, so the
//    list alternates positive / negative runs. A negative run at either end is
//    discarded because it can never help.
// 2. Start from "take every positive run". While more than k positive runs remain,
//    pop the run with the smallest |sum| and pay |sum| to get rid of one positive
//    run. A run at either end is removed together with its neighbour. Any other
//    run is fused with both neighbours: this drops a positive run, or bridges two
//    positive runs across a negative one.
//
// n: number of leading elements of inp to use (clamped to inp.size()).
// k: maximum number of subarrays.
// Returns the best total (0 when nothing positive can be chosen).
long long greedySolve(int n, int k, const vector<int>& inp){
    // Step 1: build the run list straight from the non-zero inputs.
    LinkedList<long long> segments;
    const int limit = min(n, static_cast<int>(inp.size()));
    for(int i = 0; i < limit; i++){
        if(inp[i] != 0){
            segments.append(inp[i]);
        }
    }
    if(segments.frontInd == -1) return 0;
    if(segments.front() < 0){
        segments.erase(segments.frontInd);
    }
    if(segments.frontInd == -1) return 0;
    if(segments.back() < 0){
        segments.erase(segments.backInd);
    }
    if(segments.backInd == -1) return 0;

    // Step 2: heap entries are (-|sum|, (node index, node version)), so the top is
    // the cheapest run. An entry whose version no longer matches its node is stale.
    using Entry = pair<long long, pair<int, int>>;
    priority_queue<Entry> cheapest;
    long long ans = 0;
    int positiveRuns = 0;
    for(int cur = segments.frontInd; cur != -1; cur = segments.get(cur).next){
        const long long sum = segments.get(cur).val;
        cheapest.push(Entry(-abs(sum), make_pair(cur, segments.get(cur).version)));
        if(sum > 0){
            ans += sum;
            positiveRuns++;
        }
    }
    // The emptiness check only matters for k < 0, where the queue can run dry
    // before the loop condition fails.
    while(positiveRuns > k && !cheapest.empty()){
        const Entry top = cheapest.top();
        cheapest.pop();
        const int node = top.second.first;
        if(segments.get(node).version != top.second.second) continue;
        ans -= abs(segments.get(node).val);
        const bool atEnd = node == segments.frontInd || node == segments.backInd;
        segments.merge(node);
        if(atEnd){
            segments.erase(node);
        }else{
            cheapest.push(Entry(-abs(segments.get(node).val), make_pair(node, segments.get(node).version)));
        }
        positiveRuns--;
    }
    return ans;
}
// Reference DP: maximum total sum of at most k disjoint subarrays of a[0..n-1]
// (choosing nothing scores 0). Runs in O(n*k) time and O(k) memory.
//
// After processing element i:
//   closed[j] - best total using j segments with element i outside any segment,
//   open[j]   - best total using j segments with element i ending segment j.
// Unreached states hold 0, which stands for "nothing chosen yet". Extending such a
// state is equivalent to starting a segment later, so every stored value is
// achievable with at most j segments.
//
// n: number of leading elements of a to use (clamped to a.size()).
// k: maximum number of segments; k <= 0 or an empty input yields 0.
long long dpSolve(int n, int k, const vector<int>& a){
    n = min(n, static_cast<int>(a.size()));
    if(n <= 0 || k <= 0) return 0;
    vector<long long> closed(k + 1, 0);
    vector<long long> open(k + 1, 0);
    open[1] = max(0, a[0]);
    vector<long long> nextClosed(k + 1, 0);
    vector<long long> nextOpen(k + 1, 0);
    for(int i = 1; i < n; i++){
        for(int j = 1; j <= k; j++){
            nextClosed[j] = max({0LL, closed[j], open[j]});
            // Start segment j at i, or extend the segment that ended at i-1.
            nextOpen[j] = max({0LL, closed[j - 1] + a[i], open[j] + a[i]});
        }
        // Index 0 is never written, so it stays 0 in every row.
        closed.swap(nextClosed);
        open.swap(nextOpen);
    }
    long long bestAns = 0;
    for(int j = 1; j <= k; j++){
        bestAns = max({bestAns, closed[j], open[j]});
    }
    return bestAns;
}
// Stress test: runs `t` random cases and compares greedySolve against the reference
// dpSolve. Each case has n in [1, maxN], k in [1, n] and values in [-maxAbs, maxAbs].
// Every mismatch is printed, followed by the percentage of agreeing cases.
void bruteforce(int t, int maxN = 10, int maxAbs = 10){
    if(t <= 0){
        // Avoid 0/0 in the accuracy computation below.
        cout<<"No trials requested\n";
        return;
    }
    // uniform_int_distribution requires lower bound <= upper bound.
    maxN = max(maxN, 1);
    maxAbs = max(maxAbs, 0);
    mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
    uniform_int_distribution<int> nDist(1, maxN);
    uniform_int_distribution<int> aDist(-maxAbs, maxAbs);
    int successes = 0;
    for(int trial = 0; trial < t; trial++){
        const int n = nDist(rng);
        uniform_int_distribution<int> kDist(1, n);
        const int k = kDist(rng);
        vector<int> a(n);
        for(int& x : a){
            x = aDist(rng);
        }
        // Keep the full 64-bit results; narrowing to int could hide or fake mismatches.
        const long long greedyAns = greedySolve(n, k, a);
        const long long dpAns = dpSolve(n, k, a);
        if(greedyAns != dpAns){
            cout<<"MISMATCH FOUND!\n"<<n<<" "<<k<<"\n";
            for(int i = 0; i < n; i++){
                cout<<a[i]<<" ";
            }
            cout<<"\nExpected: "<<dpAns<<"\nFound: "<<greedyAns<<"\n"<<endl;
        }else{
            successes++;
        }
    }
    const double accuracy = static_cast<double>(successes) / t * 100;
    cout<<accuracy<<"% accuracy\n";
}
// Entry point: reads n and k, then n integers, and prints greedySolve's answer.
// dpSolve is the O(n*k) reference implementation; a switch to it for small inputs
// (n <= 2000 or k == 1) was left disabled.
int main(){
    cin.tie(nullptr);
    ios_base::sync_with_stdio(false);
    // Set to true to cross-check greedySolve against dpSolve on random inputs.
    constexpr bool kRunStressTest = false;
    if(kRunStressTest){
        bruteforce(10000);
    }
    int n = 0;
    int k = 0;
    // A failed read can leave n unmodified, and a negative n would make vector throw;
    // bail out instead of proceeding.
    if(!(cin>>n>>k) || n < 0){
        return 0;
    }
    vector<int> a(n);
    for(int& x : a){
        cin>>x;
    }
    cout<<greedySolve(n, k, a)<<'\n';
}
/*
Sample input (n = 3, k = 1; a = 2 -1 3). Expected output: 4 (the single segment 2 -1 3).
3 1
2 -1 3
*/
/*
Judge results table pasted from the submission page (results had not loaded yet):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/