제출 #1257388

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1257388 | ro9669 | 축제 (IOI25_festival) | C++20
100 / 100
81 ms | 24244 KiB
#include "festival.h"
#include <bits/stdc++.h>
#define sz(a) int(a.size())
#define all(a) a.begin(),a.end()
#define fi first
#define se second
using namespace std;

using ll = long long;
using ii = pair<int, int>;
constexpr ll inf = 1000000000000000000LL;  // saturation value (1e18) for token counts
constexpr int maxN = 200000 + 7;

// Problem input, published by max_coupons(): starting tokens, prices, multipliers.
int w;
vector<int> p, t;

// Tokens held after buying coupon i while holding `cur` tokens: (cur - p[i]) * t[i].
// The product is saturated at `inf` to avoid overflow. The result is negative when
// cur < p[i] (the coupon is unaffordable); callers rely on that.
ll cost(ll cur, int i) {
    const ll remaining = cur - p[i];
    const ll mult = t[i];
    return remaining > inf / mult ? inf : remaining * mult;
}

namespace sub5{
    vector<ii> g[4];
    int id[4];

    ll S[maxN] , dp[107][107][107];
    int trace[107][107][107];

    int calc(ll x){
        int l = 0 , r = sz(g[0]) , ans = 0;
        while (l <= r){
            int mid = (l + r) / 2;
            if (x >= S[mid]){
                ans = mid;
                l = mid + 1;
            }
            else{
                r = mid - 1;
            }
        }
        return ans;
    }

    vector<int> solve(){
        int n = sz(p);
        for (int i = 0 ; i < n ; i++){
            g[t[i] - 1].push_back({p[i] , i});
        }
        for (int i = 0 ; i < 4 ; i++){
            sort(all(g[i]));
            id[i] = 0;
        }
        S[0] = 0;
        for (int i = 0 ; i < sz(g[0]) ; i++){
            S[i + 1] = S[i] + 1ll * g[0][i].fi;
        }
        vector<int> tmp_lis;
        while (true){
            vector<int> tmp;
            for (int i = 1 ; i < 4 ; i++){
                if (id[i] < sz(g[i])) tmp.push_back(i);
            }
            if (tmp.empty()) break;
            for (int i = 0 ; i < sz(tmp) ; i++){
                for (int j = i + 1 ; j < sz(tmp) ; j++){
                    ll X = 1ll * (tmp[i] + 1) * tmp[j] * g[tmp[i]][id[tmp[i]]].fi;
                    ll Y = 1ll * tmp[i] * (tmp[j] + 1) * g[tmp[j]][id[tmp[j]]].fi;
                    if (Y <= X) swap(tmp[i] , tmp[j]);
                }
            }
            int pos = g[tmp[0]][id[tmp[0]]].se;
            tmp_lis.push_back(pos);
            id[tmp[0]]++;
        }
        memset(id , 0 , sizeof(id));
        ll cur = w;
        vector<int> ans;
        for (int i = 0 ; i < sz(tmp_lis) ; i++){
            if (cur <= cost(cur , tmp_lis[i])){
                cur = cost(cur , tmp_lis[i]);
                ans.push_back(tmp_lis[i]);
                id[t[tmp_lis[i]] - 1]++;
            }
        }
        memset(dp , -1 , sizeof(dp));
        memset(trace , -1 , sizeof(trace));
        dp[0][0][0] = cur;
        pair<int , int> res = {0 , calc(cur)};
        int x = -1 , y = -1 , z = -1;
        for (int i = 0 ; i <= min(sz(g[1]) - id[1] , 100) ; i++){
            for (int j = 0 ; j <= min(sz(g[2]) - id[2] , 100) ; j++){
                for (int k = 0 ; k <= min(sz(g[3]) - id[3] , 100) ; k++){
                    if (i > 0){
                        int pos = g[1][id[1] + i - 1].se;
                        ll tmp = cost(dp[i - 1][j][k] , pos);
                        if (dp[i][j][k] < tmp){
                            dp[i][j][k] = tmp;
                            trace[i][j][k] = pos;
                        }
                    }
                    if (j > 0){
                        int pos = g[2][id[2] + j - 1].se;
                        ll tmp = cost(dp[i][j - 1][k] , pos);
                        if (dp[i][j][k] < tmp){
                            dp[i][j][k] = tmp;
                            trace[i][j][k] = pos;
                        }
                    }
                    if (k > 0){
                        int pos = g[3][id[3] + k - 1].se;
                        ll tmp = cost(dp[i][j][k - 1] , pos);
                        if (dp[i][j][k] < tmp){
                            dp[i][j][k] = tmp;
                            trace[i][j][k] = pos;
                        }
                    }
                    if (dp[i][j][k] >= 0){
                        if (i + j + k + calc(dp[i][j][k]) >= res.fi + res.se){
                            res = {i + j + k , calc(dp[i][j][k])};
                            x = i; y = j; z = k;
                        }
                    }
                }
            }
        }
        vector<int> tmp_ans;
        while (x > 0 || y > 0 || z > 0){
            int pos = trace[x][y][z];
            tmp_ans.push_back(pos);
            if (t[pos] == 2) x--;
            if (t[pos] == 3) y--;
            if (t[pos] == 4) z--;
        }
        reverse(all(tmp_ans));
        for (int x : tmp_ans) ans.push_back(x);
        for (int i = 0 ; i < res.se ; i++) ans.push_back(g[0][i].se);
        return ans;
    }
}

vector<int> max_coupons(int W , vector<int> P , vector<int> T){
    // Hand the inputs to the file-level globals that cost() and sub5::solve() read.
    // P and T are by-value copies, so moving them out is safe.
    p = std::move(P);
    t = std::move(T);
    w = W;
    return sub5::solve();
}

// int main(){
//     ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
//     freopen("templete.inp" , "r" , stdin);
//     freopen("templete.out" , "w" , stdout);
//     int n , W; cin >> n >> W;
//     vector<int> P(n) , T(n);
//     for (int &x : P) cin >> x;
//     for (int &x : T) cin >> x;
//     vector<int> ans = max_coupons(W , P , T);
//     cout << sz(ans) << "\n";
//     //for (int x : ans) cout << x << " ";
//     //cerr << sz(ans) << " " << n << "\n";
//     return 0;
// }
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...