| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|---|---|---|---|---|---|---|
| 1283013 | | takoshanava | Festival (IOI25_festival) | C++20 | | 0 ms | 0 KiB |
#include <bits/stdc++.h>
#include "festival.h"
#define pb push_back
using namespace std;
// Large sentinel value. 1e18 does not fit in a 32-bit int, so the constant
// has to be a 64-bit integer.
const long long INF = 1000000000000000000LL;
// Shorthand for the 64-bit signed integer type.
using ll = long long;
/**
 * Picks an order in which to buy coupons so that as many as possible are
 * bought, and returns the coupon indices in purchase order.
 *
 * Model implemented here: buying coupon i while holding X >= P[i] tokens
 * leaves (X - P[i]) * T[i] tokens.
 *
 * Strategy:
 *   1. Coupons with T[i] == 1 only cost tokens; they are bought last,
 *      cheapest first (by prefix sum of prices).
 *   2. The other coupons are sorted so that, for every pair, the order that
 *      leaves more tokens comes first.
 *   3. A greedy pass accepts the leading coupons of that order for as long
 *      as a purchase does not reduce the token count.
 *   4. A DP over the remaining sorted coupons tracks, for each number of
 *      extra purchases j, the most tokens that can remain.
 *   5. Every j is combined with the largest affordable group of T == 1
 *      coupons and the best total is reconstructed.
 *
 * @param AA initial number of tokens
 * @param P  coupon prices
 * @param T  coupon multipliers (T[i] is read for every index of P)
 * @return   indices into P/T in the order the coupons should be bought
 */
std::vector<int> max_coupons(int AA, std::vector<int> P, std::vector<int> T){
    using i64 = long long;

    struct Coupon {
        i64 price;
        i64 factor;
        int index;
    };

    // Token counts are clamped to this value so repeated multiplication cannot
    // overflow 64 bits. NOTE(review): presumably an upper bound on the total
    // price of all coupons -- confirm against the task limits.
    constexpr i64 kTokenCap = 200000000000000LL;
    // Number of DP columns, i.e. one more than the largest number of
    // token-reducing purchases that is tracked.
    constexpr std::size_t kMaxPicks = 71;
    // Marks DP states that cannot be reached.
    constexpr i64 kUnreachable = -1000000000000000000LL;

    std::vector<Coupon> multipliers;  // coupons with T != 1
    std::vector<Coupon> flat;         // coupons with T == 1
    multipliers.reserve(P.size());
    flat.reserve(P.size());
    for (std::size_t i = 0; i < P.size(); ++i){
        const Coupon c{P[i], T[i], static_cast<int>(i)};
        (T[i] == 1 ? flat : multipliers).push_back(c);
    }

    std::sort(flat.begin(), flat.end(), [](const Coupon& a, const Coupon& b){
        if (a.price != b.price) return a.price < b.price;
        return a.index < b.index;
    });

    // Starting from X tokens, buying x then y leaves
    //   X*tx*ty - (px*tx + py)*ty,
    // and buying y then x leaves
    //   X*tx*ty - (py*ty + px)*tx.
    // The subtracted term is the cost of the order, so x must precede y when
    // its cost is the SMALLER one.
    std::sort(multipliers.begin(), multipliers.end(), [](const Coupon& x, const Coupon& y){
        const i64 cost_x_first = (x.price * x.factor + y.price) * y.factor;
        const i64 cost_y_first = (y.price * y.factor + x.price) * x.factor;
        if (cost_x_first != cost_y_first) return cost_x_first < cost_y_first;
        if (x.price != y.price) return x.price < y.price;
        return x.index < y.index;
    });

    const std::size_t m = multipliers.size();

    // Greedy prefix: keep buying in sorted order while a purchase is
    // affordable and does not shrink the token count.
    i64 tokens = AA;
    std::size_t ptr = 0;
    while (ptr < m){
        const Coupon& c = multipliers[ptr];
        if (tokens < c.price) break;
        const i64 after = (tokens - c.price) * c.factor;
        if (after < tokens) break;
        tokens = std::min(after, kTokenCap);
        ++ptr;
    }

    // best[j]: most tokens left after buying exactly j of the coupons that
    // follow the greedy prefix, among those processed so far.
    // took[r][j]: whether that optimum for the first r+1 such coupons buys
    // coupon ptr + r. Only this table is needed for reconstruction, so the
    // token values are kept as a single rolling row.
    std::vector<i64> best(kMaxPicks, kUnreachable);
    best[0] = tokens;
    std::vector<std::array<bool, kMaxPicks>> took(m - ptr);
    std::vector<i64> next;
    for (std::size_t r = 0; r < took.size(); ++r){
        const Coupon& c = multipliers[ptr + r];
        next = best;  // skipping the coupon keeps every state as it is
        for (std::size_t j = 0; j + 1 < kMaxPicks; ++j){
            if (best[j] < c.price) continue;  // unaffordable or unreachable
            const i64 after = std::min((best[j] - c.price) * c.factor, kTokenCap);
            if (after > next[j + 1]){
                next[j + 1] = after;
                took[r][j + 1] = true;
            }
        }
        best.swap(next);
    }

    // prefix[c] = total price of the c cheapest T == 1 coupons.
    // Assumes prices are non-negative, so the sums are non-decreasing.
    std::vector<i64> prefix(flat.size() + 1, 0);
    for (std::size_t i = 0; i < flat.size(); ++i){
        prefix[i + 1] = prefix[i] + flat[i].price;
    }

    // Choose the number of DP purchases that maximises the total. The default
    // (no extra purchase at all) is always valid, so the greedy prefix is
    // never discarded.
    std::size_t best_total = 0;
    std::size_t best_picks = 0;
    std::size_t best_flat = 0;
    for (std::size_t j = 0; j < kMaxPicks; ++j){
        if (best[j] < 0) continue;
        // prefix[0] == 0 <= best[j], so upper_bound never returns begin().
        const std::size_t affordable = static_cast<std::size_t>(
            std::upper_bound(prefix.begin(), prefix.end(), best[j]) - prefix.begin()) - 1;
        if (j + affordable > best_total){
            best_total = j + affordable;
            best_picks = j;
            best_flat = affordable;
        }
    }

    std::vector<int> order;
    order.reserve(ptr + best_total);

    // 1) the greedy prefix, in sorted order
    for (std::size_t i = 0; i < ptr; ++i){
        order.push_back(multipliers[i].index);
    }

    // 2) the DP picks: walk the decision table backwards, then restore the
    //    sorted (purchase) order
    const std::size_t picks_begin = order.size();
    std::size_t j = best_picks;
    for (std::size_t r = took.size(); r > 0; --r){
        if (took[r - 1][j]){
            order.push_back(multipliers[ptr + r - 1].index);
            --j;
        }
    }
    std::reverse(order.begin() + static_cast<std::ptrdiff_t>(picks_begin), order.end());

    // 3) the T == 1 coupons, most expensive of the chosen group first; any
    //    order works because their total price fits in the remaining tokens
    for (std::size_t i = best_flat; i > 0; --i){
        order.push_back(flat[i - 1].index);
    }

    return order;
}
