제출 #1267275

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1267275 | nerrrmin | Festival (IOI25_festival) | C++20
39 / 100
66 ms | 12528 KiB
// IOI 2025 "Festival" (IOI25_festival) -- partial solution.
//
// max_coupons(A, P, T): the buyer starts with A tokens. Buying coupon i costs
// P[i] tokens, after which the remaining balance is multiplied by T[i]. The
// function returns coupon indices, in purchase order, forming a sequence the
// buyer can afford.
//
//  * n <= kSmallLimit: DP over "how many coupons of each type were bought",
//    keeping the largest reachable balance for each such count vector; within
//    one type, coupons are always bought cheapest-first.
//  * otherwise: only coupons with T == 1 and T == 2 are considered (a prefix
//    of the cheapest type-2 coupons, then as many cheapest type-1 coupons as
//    fit). Types 3 and 4 are ignored on this path, so the returned sequence is
//    affordable but not necessarily the longest possible.
//
// Coupons whose type is outside 1..kMaxType are ignored.
// NOTE(review): assumes A >= 0 and P[i] >= 0 (the type-1 prefix sums must be
// non-decreasing for the binary search) -- confirm against the statement.

// Guarded so the file also compiles standalone (e.g. local tests without the
// grader headers); in the grader environment the header is found as before.
#if __has_include("festival.h")
#include "festival.h"
#endif

#include <algorithm>
#include <array>
#include <cstddef>
#include <vector>

namespace {

constexpr int kMaxType = 4;              // coupon multipliers are 1..kMaxType
constexpr std::size_t kSmallLimit = 70;  // inputs up to this size use the DP

// Balances at or above this cap are no longer multiplied, so every balance
// stays below kMaxType * kTokenCap and cannot overflow a long long.
constexpr long long kTokenCap = 100000000000000000LL;  // 1e17

struct Coupon {
    long long cost;
    int index;  // position in the caller's P / T arrays
};

// groups[t] holds the coupons of type t sorted by cost; groups[0] is unused.
using Groups = std::array<std::vector<Coupon>, kMaxType + 1>;

// Balance after a coupon of `type` was paid for, leaving `remaining` tokens.
long long multiply_balance(long long remaining, int type) {
    return remaining < kTokenCap ? remaining * type : remaining;
}

// DP for small inputs. A state is the number of coupons bought of each type
// (always the cheapest ones of that type, in cost order), flattened into one
// index with type 1 as the most significant digit. best[s] is the largest
// balance reachable in state s, or -1 if s is unreachable.
std::vector<int> solve_small(long long start_tokens, const Groups& groups) {
    std::array<std::size_t, kMaxType + 1> dim{};
    std::array<std::size_t, kMaxType + 1> stride{};
    std::size_t states = 1;
    for (int t = kMaxType; t >= 1; --t) {
        dim[t] = groups[t].size() + 1;  // 0..size coupons of type t
        stride[t] = states;
        states *= dim[t];
    }
    // Number of type-t coupons bought in state s.
    auto taken = [&](std::size_t s, int t) { return (s / stride[t]) % dim[t]; };

    std::vector<long long> best(states, -1);
    std::vector<int> last_type(states, 0);  // type bought to reach s; 0 = start
    best[0] = start_tokens;

    std::size_t best_state = 0;
    std::size_t best_count = 0;
    // Every transition moves to a larger index, so a single increasing sweep
    // visits each state only after all of its predecessors are final.
    for (std::size_t s = 0; s < states; ++s) {
        const long long balance = best[s];
        if (balance < 0) {
            continue;  // unreachable
        }

        std::size_t count = 0;
        for (int t = 1; t <= kMaxType; ++t) {
            count += taken(s, t);
        }
        if (count > best_count) {
            best_count = count;
            best_state = s;
        }

        for (int t = 1; t <= kMaxType; ++t) {
            const std::size_t k = taken(s, t);
            if (k == groups[t].size()) {
                continue;  // every coupon of type t is already bought
            }
            const long long cost = groups[t][k].cost;
            if (balance < cost) {
                continue;
            }
            const long long after = multiply_balance(balance - cost, t);
            const std::size_t next = s + stride[t];
            // On ties the later transition wins; either path is valid.
            if (best[next] <= after) {
                best[next] = after;
                last_type[next] = t;
            }
        }
    }

    // Walk back from the best state, undoing one purchase at a time.
    std::vector<int> order;
    for (std::size_t s = best_state; last_type[s] != 0;) {
        const int t = last_type[s];
        order.push_back(groups[t][taken(s, t) - 1].index);
        s -= stride[t];
    }
    std::reverse(order.begin(), order.end());
    return order;
}

// Heuristic for large inputs that looks only at types 1 and 2: try every
// prefix of the cheapest type-2 coupons (bought first, cheapest first) and
// top up with as many of the cheapest type-1 coupons as the balance allows.
std::vector<int> solve_types_1_and_2(long long start_tokens, const Groups& groups) {
    const std::vector<Coupon>& ones = groups[1];
    const std::vector<Coupon>& twos = groups[2];

    // prefix[k] = total cost of the k cheapest type-1 coupons.
    std::vector<long long> prefix;
    prefix.reserve(ones.size() + 1);
    prefix.push_back(0);
    for (const Coupon& c : ones) {
        prefix.push_back(prefix.back() + c.cost);
    }

    // Largest k with prefix[k] <= balance.
    auto ones_affordable = [&prefix](long long balance) -> std::size_t {
        if (balance < 0) {
            return 0;
        }
        const auto it = std::upper_bound(prefix.begin(), prefix.end(), balance);
        return static_cast<std::size_t>(it - prefix.begin()) - 1;
    };

    long long balance = start_tokens;
    std::size_t best_total = ones_affordable(balance);
    std::size_t best_twos = 0;
    for (std::size_t k = 0; k < twos.size(); ++k) {
        if (balance < twos[k].cost) {
            break;
        }
        balance = multiply_balance(balance - twos[k].cost, 2);
        const std::size_t total = (k + 1) + ones_affordable(balance);
        if (total > best_total) {
            best_total = total;
            best_twos = k + 1;
        }
    }

    std::vector<int> order;
    order.reserve(best_total);
    for (std::size_t k = 0; k < best_twos; ++k) {
        order.push_back(twos[k].index);
    }
    for (std::size_t k = 0; k < best_total - best_twos; ++k) {
        order.push_back(ones[k].index);
    }
    return order;
}

}  // namespace

std::vector<int> max_coupons(int A, std::vector<int> P, std::vector<int> T) {
    // All state is local, so repeated calls do not see each other's coupons.
    Groups groups;
    const std::size_t n = std::min(P.size(), T.size());
    for (std::size_t i = 0; i < n; ++i) {
        if (T[i] < 1 || T[i] > kMaxType) {
            continue;  // unknown coupon type: ignored
        }
        groups[static_cast<std::size_t>(T[i])].push_back({P[i], static_cast<int>(i)});
    }

    // Cheapest first; the index tie-break makes the output deterministic.
    const auto by_cost_then_index = [](const Coupon& x, const Coupon& y) {
        return x.cost != y.cost ? x.cost < y.cost : x.index < y.index;
    };
    for (std::vector<Coupon>& group : groups) {
        std::sort(group.begin(), group.end(), by_cost_then_index);
    }

    if (n <= kSmallLimit) {
        return solve_small(A, groups);
    }
    return solve_types_1_and_2(A, groups);
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...