#include "festival.h"
#include<bits/stdc++.h>
#define pb push_back
using namespace std;
// Capacity of the global dp table (presumably the coupon-count limit plus slack).
constexpr long long maxn = 200010;
// Global scratch: number of coupons and the starting token balance.
long long n;
long long a;
// One coupon: its price, its type (the token multiplier) and its position in the input.
// Field order matters: callers destructure it as [cost, type, index].
struct coupon
{
    long long cost;
    long long type;
    long long index;

    coupon() {}
    coupon(long long cost_, long long type_, long long index_)
        : cost(cost_), type(type_), index(index_) {}
};

// Orders coupons by price only; equal prices compare as equivalent.
bool cmp(const coupon& lhs, const coupon& rhs)
{
    return lhs.cost < rhs.cost;
}
vector < coupon > g[5];
vector < long long > sums;
long long dp[maxn];
// Binary search over the (non-decreasing) global `sums`: returns the largest index k
// with sums[k] <= tokens, or 0 when there is no such index or `sums` is empty.
// With `sums` holding type-1 prefix sums, this is how many of the cheapest type-1
// coupons `tokens` can pay for.
long long getone(long long tokens)
{
    long long lo = 0;
    long long hi = static_cast<long long>(sums.size()) - 1;
    long long found = 0;
    while (lo <= hi)
    {
        const long long probe = (lo + hi) / 2;
        if (sums[probe] > tokens)
        {
            hi = probe - 1;
        }
        else
        {
            found = probe;
            lo = probe + 1;
        }
    }
    return found;
}
// Search node for solve_70: for each type k, cntk is the index (in price order) of the
// last type-k coupon bought, or -1 when none has been bought yet.
struct state
{
    long long cnt1, cnt2, cnt3, cnt4;

    state() {}
    state(long long c1, long long c2, long long c3, long long c4)
        : cnt1(c1), cnt2(c2), cnt3(c3), cnt4(c4) {}

    bool operator==(const state& other) const
    {
        return std::tie(cnt1, cnt2, cnt3, cnt4) ==
               std::tie(other.cnt1, other.cnt2, other.cnt3, other.cnt4);
    }
};

// Hash for `state`: seeded with cnt1, then folds in hash(cnt + 1) of each field
// in order cnt1..cnt4, base 100 (unsigned wrap-around).
struct stateHash {
    std::size_t operator()(const state& s) const {
        const long long fields[4] = {s.cnt1, s.cnt2, s.cnt3, s.cnt4};
        std::size_t acc = static_cast<std::size_t>(s.cnt1);
        for (long long field : fields)
            acc = acc * 100 + std::hash<long long>()(field + 1);
        return acc;
    }
};
unordered_map < state, long long, stateHash > mp;
unordered_map < state, bool, stateHash > valid;
unordered_map < state, long long, stateHash> from;
long long total1, total2, total3, total4;
long long mx;
state best;
void solve_70()
{
mp[state(-1, -1, -1, -1)] = a;
valid[state(-1, -1, -1, -1)] = 1;
mx = 0;
best = state(-1, -1, -1, -1);
for (long long p1 = -1; p1 < total1; ++ p1)
{
for (long long p2 = -1; p2 < total2; ++ p2)
{
for (long long p3 = -1; p3 < total3; ++ p3)
{
for (long long p4 = -1; p4 < total4; ++ p4)
{
if(valid[state(p1, p2, p3, p4)] == 0)continue;
///
long long all = p1 + 1 + p2 + 1 + p3 + 1 + p4 + 1;
// cout << "fff " << from[state(p1, p2, p3, p4)] << endl;
if(all > mx)
{
mx = all;
best = state(p1, p2, p3, p4);
}
long long tokens = mp[state(p1, p2, p3, p4)];
if(p1 < total1-1)
{
long long cost1 = g[1][p1+1].cost;
if(tokens >= cost1)
{
long long tokens1 = tokens - cost1;
tokens1 *= 1;
state newstate = state(p1+1, p2, p3, p4);
if(mp[newstate] <= tokens1)
{
valid[newstate] = 1;
mp[newstate] = max(mp[newstate], tokens1);
from[newstate] = 1;
}
}
}
if(p2 < total2-1)
{
long long cost2 = g[2][p2+1].cost;
if(tokens >= cost2)
{
long long tokens2 = tokens - cost2;
if(tokens2 < 1LL * 1e17)tokens2 *= 2;;
state newstate = state(p1, p2+1, p3, p4);
if(mp[newstate] <= tokens2)
{valid[newstate] = 1;
mp[newstate] = max(mp[newstate], tokens2);
from[newstate] = 2;
}
}
}
if(p3 < total3-1)
{
long long cost3 = g[3][p3+1].cost;
if(tokens >= cost3)
{
long long tokens3 = tokens - cost3;
if(tokens3 < 1LL * 1e17)tokens3 *= 3;
state newstate = state(p1, p2, p3+1, p4);
if(mp[newstate] <= tokens3)
{valid[newstate] = 1;
mp[newstate] = max(mp[newstate], tokens3);
from[newstate] = 3;
}
}
}
if(p4 < total4-1)
{
long long cost4 = g[4][p4+1].cost;
if(tokens >= cost4)
{
long long tokens4 = tokens - cost4;
if(tokens4 < 1LL * 1e17)tokens4 *= 4;
state newstate = state(p1, p2, p3, p4+1);
if(mp[newstate] <= tokens4)
{ valid[newstate] = 1;
mp[newstate] = max(mp[newstate], tokens4);
from[newstate] = 4;
}
}
}
}
}
}
}
}
namespace festival_detail
{
// Balances are clamped at 1e17. Prices are ints and (assuming n <= 2e5, as maxn
// suggests) all prices together sum to less than 1e15. So a clamped balance can
// still afford every remaining coupon, and t * balance (t <= 4) stays far from
// long long overflow.
constexpr long long kTokenCap = 100000000000000000LL;

// Cap on multiplier purchases after the greedy phase. Each such purchase at least
// doubles the gap between the balance and the next coupon's key (see max_coupons).
// With int prices that allows only about 35 purchases; 70 is a safety margin.
constexpr int kMaxLossyPurchases = 70;

// Balance left after paying `price` out of `tokens` (tokens >= price) for a coupon
// with multiplier `type`, clamped to kTokenCap.
inline long long after_purchase(long long tokens, long long price, long long type)
{
    return std::min(kTokenCap, type * (tokens - price));
}
}  // namespace festival_detail

// Returns the purchase order (input indices) of a longest affordable sequence of
// coupons. Buying coupon i with balance X requires X >= P[i] and leaves
// T[i] * (X - P[i]). Assumes every T[i] is in 1..4. Uses only local state, so
// repeated calls are independent.
//
// Why this is optimal (exchange arguments):
//  * Buying a type-1 coupon later is never worse, so type-1 coupons go last,
//    cheapest first.
//  * For multiplier coupons bought back to back, i before j is never worse iff
//    P[i]*T[i]/(T[i]-1) <= P[j]*T[j]/(T[j]-1). So multipliers are bought in that
//    key order.
//  * While the first remaining multiplier (in key order) does not lower the
//    balance, every optimal answer can start with it (greedy phase).
//  * Afterwards each purchase lowers the balance, and the gap to the next key at
//    least doubles. So few more multipliers fit. A DP over (coupons considered,
//    coupons bought) keeps the largest balance per count, and each count is then
//    completed with affordable type-1 coupons.
std::vector<int> max_coupons(int A, std::vector<int> P, std::vector<int> T) {
    namespace fd = festival_detail;
    const int coupon_count = static_cast<int>(P.size());

    // Split input indices into multiplier coupons (type >= 2) and type-1 coupons.
    std::vector<int> multipliers;
    std::vector<int> plain;
    for (int i = 0; i < coupon_count; ++i) {
        if (T[i] == 1)
            plain.push_back(i);
        else
            multipliers.push_back(i);
    }

    // Multipliers by key P*T/(T-1), compared exactly via cross-multiplication
    // (|P*T*(T-1)| < 2^31 * 12 fits in long long); ties broken by index.
    std::sort(multipliers.begin(), multipliers.end(), [&P, &T](int i, int j) {
        const long long lhs = 1LL * P[i] * T[i] * (T[j] - 1);
        const long long rhs = 1LL * P[j] * T[j] * (T[i] - 1);
        return lhs != rhs ? lhs < rhs : i < j;
    });

    // Type-1 coupons cheapest first, with prefix sums of their prices (prefix[0] == 0).
    std::sort(plain.begin(), plain.end(), [&P](int i, int j) {
        return P[i] != P[j] ? P[i] < P[j] : i < j;
    });
    std::vector<long long> prefix(plain.size() + 1, 0);
    for (std::size_t i = 0; i < plain.size(); ++i)
        prefix[i + 1] = prefix[i] + P[plain[i]];

    // How many of the cheapest type-1 coupons a balance of `tokens` pays for.
    const auto affordable_plain = [&prefix](long long tokens) -> int {
        if (tokens < 0)
            return 0;
        return static_cast<int>(std::upper_bound(prefix.begin(), prefix.end(), tokens) - prefix.begin()) - 1;
    };

    // Greedy phase: buy multipliers in key order while a purchase does not lower the balance.
    std::vector<int> order;
    long long tokens = A;
    std::size_t cursor = 0;
    for (; cursor < multipliers.size(); ++cursor) {
        const int id = multipliers[cursor];
        if (tokens < P[id])
            break;
        const long long after = fd::after_purchase(tokens, P[id], T[id]);
        if (after < tokens)
            break;
        order.push_back(id);
        tokens = after;
    }

    // DP phase over the remaining multipliers (still in key order).
    const std::vector<int> rest(multipliers.begin() + static_cast<std::ptrdiff_t>(cursor), multipliers.end());
    const int m = static_cast<int>(rest.size());
    const int limit = std::min(m, fd::kMaxLossyPurchases);
    const std::size_t width = static_cast<std::size_t>(limit) + 1;
    constexpr long long kUnreachable = std::numeric_limits<long long>::min();

    // most[k]: largest balance after buying exactly k of the coupons of `rest` seen so far.
    // took[i * width + k]: processing rest[i] improved most[k] (used to rebuild the path).
    std::vector<long long> most(width, kUnreachable);
    most[0] = tokens;
    std::vector<char> took(static_cast<std::size_t>(m) * width, 0);
    for (int i = 0; i < m; ++i) {
        const int id = rest[i];
        // Descending k, so most[k - 1] still holds its value from before coupon i.
        for (int k = limit; k >= 1; --k) {
            if (most[k - 1] < P[id])  // also rejects kUnreachable
                continue;
            const long long after = fd::after_purchase(most[k - 1], P[id], T[id]);
            if (after > most[k]) {
                most[k] = after;
                took[static_cast<std::size_t>(i) * width + k] = 1;
            }
        }
    }

    // Pick the count of extra multipliers that maximizes the total including type-1 coupons.
    int best_k = 0;
    long long best_total = -1;
    for (int k = 0; k <= limit; ++k) {
        if (most[k] == kUnreachable)
            continue;
        const long long total_k = k + affordable_plain(most[k]);
        if (total_k > best_total) {
            best_total = total_k;
            best_k = k;
        }
    }

    // Walk back through `took`: scanning from the end, the first flagged coupon at
    // level k produced the final most[k]; then continue one level down.
    std::vector<int> lossy;
    for (int i = m - 1, k = best_k; i >= 0 && k > 0; --i) {
        if (took[static_cast<std::size_t>(i) * width + k]) {
            lossy.push_back(rest[i]);
            --k;
        }
    }
    order.insert(order.end(), lossy.rbegin(), lossy.rend());

    const int plain_taken = affordable_plain(most[best_k]);
    order.insert(order.end(), plain.begin(), plain.begin() + plain_taken);
    return order;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |