// 제출 #855932
//
// # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
// 855932 | Sorting | 추월 (IOI23_overtaking) | C++17
// 65 / 100
// 3558 ms | 153096 KiB
#include "overtaking.h"

#include <algorithm>
#include <map>
#include <numeric>
#include <utility>
#include <vector>

using namespace std;
using ll = long long;

// Capacity (with a little slack) for both the number of buses and the number
// of sorting stations; every per-bus and per-station table is sized by it.
constexpr int N = 1000 + 3;

// dp[j][b]: arrival time at the last station of a bus that moves at speed x
// (the reserve bus's speed) and is at station j at time time_at_sorting[j][b].
ll dp[N][N];
int n, m;
ll t[N], w[N], s[N], l, x;

// time_at_sorting[j][b]: time at which kept bus b reaches station j.
ll time_at_sorting[N][N];

// Buses reaching a station at the same time form a group. For station j the
// groups are stored as parallel arrays ordered by strictly increasing time:
//   group_time[j][g]    - the group's arrival time at station j;
//   group_prefix[j][g]  - number of buses arriving no later than group_time[j][g];
//   group_slowest[j][g] - largest bus index in the group. After reorder_buses()
//                         that bus has the largest w of the group, hence the
//                         latest arrival at the next station among the group.
// Plain sorted arrays are searched with std::lower_bound. A per-station
// std::map (one heap node plus one vector per group) is far slower to search,
// and these searches run O(log M) times per query and per dp entry.
std::vector<ll> group_time[N];
std::vector<int> group_prefix[N];
std::vector<int> group_slowest[N];

// Drops every bus with w <= x and sorts the rest by (w, t) ascending, so a
// larger index never means a faster bus.
// A dropped bus never delays the reserve bus or any bus slower than it: when it
// is strictly ahead of such a bus, its own expected time is strictly smaller.
void reorder_buses() {
    std::vector<std::pair<ll, ll>> kept;  // (w, t) of buses slower than x
    kept.reserve(n);
    for (int i = 0; i < n; ++i) {
        if (w[i] > x) {
            kept.emplace_back(w[i], t[i]);
        }
    }
    std::sort(kept.begin(), kept.end());

    n = static_cast<int>(kept.size());
    for (int i = 0; i < n; ++i) {
        w[i] = kept[i].first;
        t[i] = kept[i].second;
    }
}

// Rebuilds the group arrays of `station` from `order`, which must list all n
// kept buses sorted by (time_at_sorting[station][b], b).
void build_groups(int station, const std::vector<int>& order) {
    std::vector<ll>& times = group_time[station];
    std::vector<int>& prefix = group_prefix[station];
    std::vector<int>& slowest = group_slowest[station];
    times.clear();
    prefix.clear();
    slowest.clear();

    const ll* const when = time_at_sorting[station];
    int head = 0;
    while (head < n) {
        int tail = head;
        while (tail < n && when[order[tail]] == when[order[head]]) {
            ++tail;
        }
        times.push_back(when[order[head]]);
        prefix.push_back(tail);
        slowest.push_back(order[tail - 1]);  // ties in `order` are broken by index
        head = tail;
    }
}

// Simulates the kept buses station by station, filling time_at_sorting and the
// group arrays for stations 0..m-1. Existing group data for those stations is
// replaced, so init() may be called more than once.
void init_groups() {
    std::vector<int> order(n);
    std::iota(order.begin(), order.end(), 0);

    for (int station = 0; station < m; ++station) {
        ll* const now = time_at_sorting[station];
        if (station == 0) {
            std::copy(t, t + n, now);
        } else {
            const ll* const before = time_at_sorting[station - 1];
            const ll dist = s[station] - s[station - 1];
            // `order` is sorted by (time at the previous station, index). A bus
            // is held back to the largest expected time of the buses strictly
            // ahead of it. Buses tied with it come earlier in `order` only if
            // their index is smaller, i.e. their w (and so their expected time)
            // is not larger. Folding them into the running maximum is harmless.
            ll bunch = 0;
            for (const int b : order) {
                bunch = std::max(bunch, before[b] + dist * w[b]);
                now[b] = bunch;
            }
        }
        std::sort(order.begin(), order.end(), [now](int a, int b) {
            return now[a] != now[b] ? now[a] < now[b] : a < b;
        });
        build_groups(station, order);
    }
}

void init_buses() {
    reorder_buses();
    init_groups();
}

// Index of the last group at `station` whose time is strictly less than `time`,
// or -1 if there is none.
int last_group_before(int station, ll time) {
    const std::vector<ll>& times = group_time[station];
    const auto first_not_less = std::lower_bound(times.begin(), times.end(), time);
    return static_cast<int>(first_not_less - times.begin()) - 1;
}

// Number of kept buses that reach `station` strictly before `time`.
int buses_before(int station, ll time) {
    const int g = last_group_before(station, time);
    return g < 0 ? 0 : group_prefix[station][g];
}

// Arrival time at the last station of a bus moving at speed x that is at
// `station` at time `time`. Since x is faster than every kept bus, such a bus
// never delays the kept buses, so their precomputed times remain valid.
//
// Only buses strictly ahead can delay it. A kept bus is slower than x, so once
// it is no longer strictly ahead of the free-running bus it never gets ahead
// again. The number of buses ahead is therefore non-increasing along the road,
// and we binary-search for the first station where it drops. There the bus is
// caught up in the bunch whose latest member is the slowest bus of the last
// group ahead of it. From then on it shares that bus's future (dp).
// Cost: O(log m * log n).
ll calc_time(int station, ll time) {
    if (station == m - 1) {
        return time;
    }

    const ll free_arrival = (s[m - 1] - s[station]) * x + time;
    const int ahead = buses_before(station, time);
    if (ahead == 0) {
        return free_arrival;  // nobody ahead, so nobody can ever block it
    }

    int lo = station + 1;
    int hi = m;
    while (lo < hi) {
        const int mid = lo + (hi - lo) / 2;
        const ll free_at_mid = (s[mid] - s[station]) * x + time;
        if (buses_before(mid, free_at_mid) < ahead) {
            hi = mid;
        } else {
            lo = mid + 1;
        }
    }
    if (lo == m) {
        return free_arrival;
    }

    const int prev_station = lo - 1;
    const ll free_at_prev = (s[prev_station] - s[station]) * x + time;
    const int g = last_group_before(prev_station, free_at_prev);
    if (g < 0) {
        return free_arrival;  // defensive: `ahead` buses are still ahead here
    }
    return dp[lo][group_slowest[prev_station][g]];
}

// Fills dp from the last station backwards; calc_time(j, ...) reads dp rows > j.
void calc_dp() {
    for (int i = 0; i < n; ++i) {
        dp[m - 1][i] = time_at_sorting[m - 1][i];
    }
    for (int station = m - 2; station >= 0; --station) {
        for (int i = 0; i < n; ++i) {
            dp[station][i] = calc_time(station, time_at_sorting[station][i]);
        }
    }
}

// Grader entry point: stores the instance and precomputes every table needed by
// arrival_time(). Can be called again for a new instance.
void init(int L, int N, std::vector<long long> T, std::vector<int> W, int X, int M, std::vector<int> S)
{
    l = L, n = N, x = X, m = M;
    std::copy(T.begin(), T.end(), t);
    std::copy(W.begin(), W.end(), w);
    std::copy(S.begin(), S.end(), s);

    init_buses();
    calc_dp();
}

// Grader entry point: arrival time at the last station of the reserve bus when
// it leaves station 0 at time Y.
ll arrival_time(ll Y)
{
    return calc_time(0, Y);
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...