| # | Time | Username | Problem | Language | Result | Execution time | Memory | 
|---|---|---|---|---|---|---|---|
| 1252885 |  | comgaTramAnh | Festival (IOI25_festival) | C++20 |  | 0 ms | 0 KiB |
#include <bits/stdc++.h> 
//#include "festival.h"
using namespace std;
/**
 * Festival (IOI25_festival): choose which coupons to buy, and in which order,
 * so that the number of coupons bought is maximised.
 *
 * Buying coupon i while holding X tokens requires X >= P[i] and leaves
 * (X - P[i]) * T[i] tokens.
 *
 * Outline:
 *  1. Coupons with T == 1 are bought last, cheapest first.
 *  2. The other coupons are sorted by the key P*T/(T-1). In that order they are
 *     bought greedily while a purchase does not lower the balance.
 *  3. From the first coupon that would lower the balance onwards, a DP over
 *     (coupons considered, coupons bought) keeps the largest balance per count.
 *     The count is capped at kMaxDecreasing. NOTE(review): this cap presumably
 *     relies on the task's limits on P and T; confirm against the statement.
 *  4. For every DP count, the number of affordable T == 1 coupons is found by
 *     binary search over prefix sums. The best total is then reconstructed.
 *
 * @param A initial number of tokens.
 * @param P price of each coupon (assumed non-negative).
 * @param T multiplier of each coupon.
 * @return indices of the coupons to buy, in purchase order.
 */
std::vector <int> max_coupons(int A, std::vector <int> P, std::vector <int> T) {
  // Balances at or above this cap are treated as "enough to buy everything".
  // DP values are clamped to it so that (balance - P) * T cannot overflow.
  const long long inf = 10000000000000007LL;
  // Marks DP states that cannot be reached.
  const long long unreachable = -inf;
  // Maximum number of balance-lowering coupons the DP may buy.
  const int kMaxDecreasing = 32;

  // i goes before j iff P[i]*T[i]/(T[i]-1) < P[j]*T[j]/(T[j]-1).
  // The division is cross-multiplied away, assuming T > 1 for these coupons.
  auto byKey = [&](int i, int j) {
    return 1LL * P[i] * T[i] * (T[j] - 1) < 1LL * P[j] * T[j] * (T[i] - 1);
  };

  std::vector <int> ones, other;
  for (int i = 0; i < (int) T.size(); i++) {
    if (T[i] == 1) {
      ones.push_back(i);
    } else {
      other.push_back(i);
    }
  }
  std::sort(ones.begin(), ones.end(), [&](int i, int j) { return P[i] < P[j]; });
  std::sort(other.begin(), other.end(), byKey);

  // Phase 1: buy in key order while a purchase does not lower the balance.
  std::vector <int> ret;
  long long X = A;
  int start = 0;
  for (; start < (int) other.size(); start++) {
    const int c = other[start];
    const long long nextX = (X - P[c]) * T[c];
    if (nextX < X) {
      break;  // first coupon that would lower the balance (or is unaffordable)
    }
    X = nextX;
    ret.push_back(c);
    if (X >= inf) {
      // Balance is huge: buy every T > 1 coupon in key order, then every T == 1 coupon.
      ret = other;
      ret.insert(ret.end(), ones.begin(), ones.end());
      return ret;
    }
  }

  // Phase 2: DP over the remaining T > 1 coupons.
  // After processing rest[0..i), best[j] is the largest balance with exactly j of them bought.
  const std::vector <int> rest(other.begin() + start, other.end());
  const int n = (int) rest.size();
  const int K = std::min(kMaxDecreasing, n);
  const int width = K + 1;
  std::vector <long long> best(width, unreachable);
  std::vector <long long> next(width, unreachable);
  best[0] = X;
  // took[i * width + j] != 0  <=>  state (i, j) was reached by buying rest[i - 1].
  std::vector <char> took((std::size_t) (n + 1) * width, 0);
  for (int i = 0; i < n; i++) {
    const int c = rest[i];
    next = best;  // default transition: skip rest[i]
    for (int j = 0; j < K; j++) {
      if (best[j] == unreachable || best[j] < P[c]) {
        continue;
      }
      const long long cand = std::min(inf, (best[j] - P[c]) * T[c]);
      if (cand > next[j + 1]) {
        next[j + 1] = cand;
        took[(std::size_t) (i + 1) * width + j + 1] = 1;
      }
    }
    best.swap(next);
  }

  // prefix[m] = total price of the m cheapest T == 1 coupons.
  std::vector <long long> prefix(ones.size() + 1, 0LL);
  for (int m = 0; m < (int) ones.size(); m++) {
    prefix[m + 1] = prefix[m] + P[ones[m]];
  }

  // Pick the DP count maximising (T > 1 coupons) + (affordable T == 1 coupons).
  // Defaults (0, 0) keep the reconstruction in bounds if no state qualifies.
  int bestJ = 0;
  int bestOnes = 0;
  int bestTotal = -1;
  for (int j = 0; j <= K; j++) {
    if (best[j] == unreachable) {
      continue;
    }
    // Largest m with prefix[m] <= best[j].
    const int m = (int) (std::upper_bound(prefix.begin(), prefix.end(), best[j]) - prefix.begin()) - 1;
    if (m >= 0 && j + m > bestTotal) {
      bestTotal = j + m;
      bestJ = j;
      bestOnes = m;
    }
  }

  // Walk the "took" flags back from (n, bestJ) to recover the bought coupons.
  std::vector <int> picked;
  for (int i = n, j = bestJ; i > 0; i--) {
    if (took[(std::size_t) i * width + j]) {
      picked.push_back(rest[i - 1]);
      j--;
    }
  }
  ret.insert(ret.end(), picked.rbegin(), picked.rend());
  ret.insert(ret.end(), ones.begin(), ones.begin() + bestOnes);
  return ret;
}
