Submission #1073699

# | Time | Username | Problem | Language | Result | Execution time | Memory
1073699 | cadmiumsky | Tricks of the Trade (CEOI23_trade) | C++17
50 / 100
6619 ms | 22144 KiB
#include <bits/stdc++.h>
#define all(x) (x).begin(),(x).end()
using namespace std;

using ll = long long;
using ld = long double;

// From this point on every `int` token expands to `ll` (long long). This is why
// the entry point is declared `signed main()`: `int main` would become `ll main`.
#define int ll
// Container size converted to `int`, which after the macro above means long long.
#define sz(x) ((int)(x).size())

// Both aliases are declared after `#define int ll`, so they hold long long values.
// NOTE(review): neither alias (nor `ld`) appears to be used in this file.
using pii = pair<int,int>;
using tii = tuple<int,int,int>;

// Capacity of the global arrays (valid indices 0..250004).
const int nmax = 25e4 + 5;

// How many values KthHeap keeps "inside"; read from input in main.
int K;

// Sentinel magnitude: KthHeap::query returns -inf when it holds fewer than K values.
const ll inf = 1e18;

struct KthHeap {
   multiset<int> outside, inside;
   void repair() {
      while(sz(inside) > K) {
         int x = *inside.begin();
         inside.erase(inside.find(x));
         outside.insert(x);
         sum -= x;
      }
      while(sz(inside) < K && sz(outside)) {
         int x = *outside.rbegin();
         outside.erase(outside.find(x));
         inside.insert(x);
         sum += x;
      }
      while(sz(inside) && sz(outside) && *inside.begin() < *outside.rbegin()) {
         int x = *outside.rbegin(), y = *inside.begin();
         outside.erase(outside.find(x));
         inside.erase(inside.find(y));
         outside.emplace(y);
         inside.emplace(x);
         sum += x - y;
      }
      return;
   }
   void erase(int x) {
      if(outside.find(x) != outside.end())
         outside.erase(outside.find(x));
      else if(inside.find(x) != inside.end())
         inside.erase(inside.find(x)), sum -= x;
      else assert(false);
      repair();
   }
   void insert(int x) {
      outside.emplace(x);
      repair();
   }
   ll query() {
      repair();
      if(sz(inside) < K) return -inf;
      return sum;
   }
   
   private:
      ll sum = 0;
};

// Per start position: the chosen cut index and its value (both filled by Divide::divide).
int bestcut[nmax];
ll dp[nmax];

// spart[i]: prefix sum of the first input array (spart[0] == 0); v[i]: second input array.
ll spart[nmax];
ll v[nmax];

// Sum of the first input array over the closed, 1-indexed range [l, r]; 0 when r == l - 1.
ll S(int l, int r) {
   const ll through_r = spart[r];
   const ll before_l = spart[l - 1];
   return through_r - before_l;
}

namespace Divide {
   KthHeap hint;
   
   //void brut(int p)  {
      //vector<int> pul;
      //int best = -inf, atr = p -1;
      //for(int i = p; i <= 200; i++) {
         //pul.emplace_back(v[i]);
         //if(sz(pul) < K) continue;
         //sort(all(pul), greater<int>());
         
         //if(accumulate(all(pul), 0ll) - S(p, i) > best) tie(best, atr) = make_pair(accumulate(all(pul), 0ll) - S(p, i), i);
      //}
      //if(bestcut[p] != atr)
         //cerr << "COAIE\nCOAIE\nCOAIE\nCOAIE\nCOAIE\nCOAIE\n";
      //cerr << p << ' ' << atr << '\n';
   //}
   
   void divide(int l, int r) {
      if(l + 1 == r) return;
      int optl = bestcut[l], optr = bestcut[r];
      
      //cerr << r << ' ' << optl << '\t' << sz(hint.inside) + sz(hint.outside) << '\n';
      
      int mid = l + r >> 1;
      if(r <= optl) {
         for(int i = mid; i < r; i++) hint.insert(v[i]);
         ll best = hint.query() - S(mid, optl);
         int atr = optl;
         for(int i = optl + 1; i <= optr; i++) {
            hint.insert(v[i]);
            if(hint.query() - S(mid, i) > best) tie(best, atr) = make_pair(hint.query() - S(mid, i), i);
         }
         bestcut[mid] = atr, dp[mid] = best;
         //cerr << mid << ' ' << best << ' ' << atr << '\t' << l << ' ' << r << " -- " << optl << ' ' << optr << '\t' << sz(hint.inside) + sz(hint.outside) << "\n\t";;
         //brut(mid); 
         //cerr << '\n';
         if(atr - mid + 1 == K) {
            int SA = S(mid, atr), SB = 0;
            for(int i = mid; i <= atr; i++) SB += v[i];
            //cout << "\t" << SB << ' ' << SA << '\t' << SB - SA<< '\n';
         }
         
         for(int i = optl + 1; i <= optr; i++) hint.erase(v[i]);
         divide(l, mid);
         for(int i = mid; i < r; i++) hint.erase(v[i]);
         for(int i = optl + 1; i <= atr; i++) hint.insert(v[i]);
         divide(mid, r);
         for(int i = optl + 1; i <= atr; i++) hint.erase(v[i]);
      }
      else {
         int border = max(mid - 1, optl);
         for(int i = mid; i <= border; i++) hint.insert(v[i]);
         ll best = hint.query() - S(mid, border), atr = border;
         for(int i = border + 1; i <= optr; i++) {
            hint.insert(v[i]);
            if(hint.query() - S(mid, i) > best) tie(best, atr) = make_pair(hint.query() - S(mid, i), i);
         }
         
         assert(atr - mid + 1 >= K);
         //cerr << mid << ' ' << best << ' ' << atr << '\n';
         //if(atr - mid + 1 == K) {
            //int SA = S(mid, atr), SB = 0;
            //for(int i = mid; i <= atr; i++) SB += v[i];
            //cout << "\t" << SB << ' ' << SA << '\t' << SB - SA<< '\n';
         //}
         
         bestcut[mid] = atr, dp[mid] = best;
         
         for(int i = optr; i > border; i--) hint.erase(v[i]);
         divide(l, mid);
         for(int i = mid; i <= border; i++) hint.erase(v[i]);
         for(int i = r; i <= atr; i++) hint.insert(v[i]);
         //cerr << '\t' << r << ' ' << atr << ' ' << sz(hint.inside) << '\n';
         divide(mid, r);
         for(int i = r; i <= atr; i++) hint.erase(v[i]);
      }
   }
   
}


signed main() {
   cin.tie(0) -> sync_with_stdio(0);
   int n;
   cin >> n >> K;
   for(int i = 1; i <= n; i++) {
      cin >> spart[i];
      spart[i] += spart[i - 1];
   }
   for(int i = 1; i <= n; i++)
      cin >> v[i];
      
   bestcut[0] = 0;
   bestcut[n - K + 2] = n;
   Divide::divide(0, n - K + 2);
   ll mn = -inf;
   for(int i = 1; i <= n - K + 1; i++) 
      mn = max(mn, dp[i]);

   cout << mn << '\n';
}


/**
      Töte es durch genaue Untersuchung\Töte es kann es nur noch schlimmer machen\Es lässt es irgendwie atmen
--
*/ 

Compilation message (stderr)

trade.cpp: In function 'void Divide::divide(ll, ll)':
trade.cpp:99:19: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   99 |       int mid = l + r >> 1;
      |                 ~~^~~
trade.cpp:113:17: warning: unused variable 'SA' [-Wunused-variable]
  113 |             int SA = S(mid, atr), SB = 0;
      |                 ^~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...