제출 #1104079

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1104079 | Ludissey | 추월 (IOI23_overtaking) | C++17
19 / 100
6 ms | 2128 KiB
#include "overtaking.h"
#include <bits/stdc++.h>
#define sz(a) (int)a.size()
#define all(a) a.begin(),a.end()
#define rall(a) a.rbegin(), a.rend()
#define int long long
using namespace std;
// Global state shared by init() and arrival_time(). Note: `int` is
// #define'd to long long above, so every `int` here is 64-bit.
int n,X,m; // n: bus count (set to N, then to the number of kept buses in init); X: reserve bus seconds per km; m: number of sorting stations
const int LOG=12; // number of binary-lifting levels in dep; jumps of up to 2^LOG - 1 stations can be encoded
vector<pair<int,int>> b; // kept buses as (W, T) pairs, sorted in decreasing order by init
vector<vector<int>> s; // s[i][j]: arrival time of bus i at station j
vector<vector<vector<int>>> dep; // dep[i][j][k]: bus reached by a 2^k-station jump along the "latest arrival ahead" chain from bus i at station j; -1 if unset
vector<vector<pair<int,int>>> sb; // sb[j]: (arrival time at station j, bus index) pairs; sorted before use
vector<vector<pair<int,int>>> sb2; // sb2[j]: (arrival time at station j, -bus index) pairs; the negated index makes the smallest index sort last among equal times
vector<vector<int>> dp; // dp[i][j]: arrival time at the last station of a reserve bus that is at station j at time s[i][j]
vector<int> a; // a[j]: position of station j (init also stores L in a[m])
vector<pair<pair<int,int>,int>> fseg; // NOTE(review): not referenced in the visible code


int petit_lower(int Y, int j){
    int l=0,r=sz(sb2[j])-1;
    int ans=-1;
    while(l<=r){
        int mid=(l+r)/2;
        if(sb2[j][mid].first<Y){
            ans=abs(sb2[j][mid].second);
            l=mid+1;
        }else{
            r=mid-1;
        }
    }
    return ans;
}

// Precomputes every table used by arrival_time().
//
// Parameters follow the task interface: road length L, N non-reserve buses
// with departure times T and seconds-per-km W, reserve bus speed _x, and M
// sorting-station positions _s. a[m-1] is used as the hotel position, so
// _s[M-1] is assumed to equal L.
//
// Buses with W[i] < X are discarded. A bus strictly faster than the reserve
// that is ahead of it always has a smaller expected arrival time, so it can
// never hold the reserve up. The remaining buses are re-indexed by decreasing
// (W, T). Inside a group of buses that share an arrival time, the smallest
// index therefore has the largest W.
//
// Fills (for the kept buses):
//   s[i][j]       arrival time of bus i at station j;
//   sb[j], sb2[j] (time, i) and (time, -i) pairs at station j, each sorted;
//   dep[i][j][k]  bus reached by a 2^k-station jump along the chain
//                 "bus with the latest next-station arrival among those no
//                 later than me" (-1 when never set / out of range);
//   dp[i][j]      arrival time at the last station of a reserve bus that is
//                 at station j at exactly time s[i][j].
void init(signed L, signed N, std::vector<long long> T, std::vector<signed> W, signed _x, signed M, std::vector<signed> _s)
{
    n=N; X=_x; m=M;
    a.resize(m+1);
    sb.resize(m+1);
    sb2.resize(m+1);
    a[m]=L;
    for (int i = 0; i < m; i++) a[i]=_s[i];

    // Keep only buses that are at least as slow as the reserve bus.
    for (int i = 0; i < n; i++) {
        if(W[i]<X) continue;
        b.push_back({W[i],T[i]});
    }
    n=sz(b);
    dep.resize(n,vector<vector<int>>(m,vector<int>(LOG,-1)));
    s.resize(n, vector<int>(m+1));
    sort(rall(b));
    for (int i = 0; i < n; i++) {
        sb[0].push_back({b[i].second,i});
        sb2[0].push_back({b[i].second,-i});
        s[i][0]=b[i].second;
    }

    // Simulate station by station. Buses are visited in (time, index) order
    // at station i-1. `mx` is the maximal arrival at station i over buses
    // strictly earlier at station i-1; only those can delay the current bus.
    // `cmx`/`cmxI` hold the maximum, and the bus achieving it first, over
    // every bus visited so far.
    for (int i = 1; i < m; i++)
    {
        sort(all(sb[i-1]));
        sort(all(sb2[i-1]));
        int mx=0;
        int cmx=0;
        int cmxI=-1;
        for (int j = 0; j < n; j++)
        {
            int x=sb[i-1][j].second;
            if(j>0&&sb[i-1][j].first!=sb[i-1][j-1].first){
                // New time group: everything visited so far is strictly ahead.
                mx=max(mx,cmx);
                cmx=mx;
            }
            s[x][i]=max(mx,sb[i-1][j].first+b[x].first*(a[i]-a[i-1]));
            if(s[x][i]>cmx) cmxI=x;
            cmx=max(s[x][i],cmx);
            dep[x][i-1][0]=cmxI;
            sb[i].push_back({s[x][i],x});
            sb2[i].push_back({s[x][i],-x});
        }
    }
    sort(all(sb2[m-1]));

    // Binary lifting: a 2^k jump is two consecutive 2^(k-1) jumps.
    for (int k = 1; k < LOG; k++)
    {
        int p=1LL<<(k-1);
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < m; j++)
            {
                if(p+j>=m) continue;
                if(dep[i][j][k-1]==-1) dep[i][j][k]=-1;
                else dep[i][j][k]=dep[dep[i][j][k-1]][j+p][k-1];
            }
        }
    }

    // At the last station the answer is the arrival time itself.
    dp.resize(n,vector<int>(m));
    for (int i = 0; i < n; i++) dp[i][m-1]=s[i][m-1];

    // A reserve bus at station i at time t = s[j][i] runs freely, reaching
    // station k at t + X*(a[k]-a[i]). It does so until the first station k at
    // which the latest bus among those strictly ahead of it at station i
    // arrives no earlier than that free-running time. There its time equals
    // s[ansX][k], so its answer is dp[ansX][k]. "Caught by station k" is
    // monotone in k because every kept bus has W >= X, which makes the binary
    // search valid.
    for (int i = m-2; i >= 0; i--)
    {
        for (int j = 0; j < n; j++)
        {
            const int t=s[j][i]; // reserve's time at station i
            int l=1; int r=(m-1)-i;
            int ans=-1;
            int ansX=-1;
            int _j=petit_lower(t,i);
            while(l<=r&&_j>=0){
                int mid=(l+r)/2;
                int x=_j;
                int dst=i;
                for (int k = LOG-1; k >= 0; k--)
                {
                    if(mid&(1<<k)){
                        x=dep[x][dst][k];
                        dst+=(1<<k);
                        if(x<0) break;
                    }
                }
                // Compare against the reserve's free-running arrival time at
                // `dst`. That time starts from t, not from 0.
                if(x>=0&&s[x][dst]>=t+X*(a[dst]-a[i])){
                    r=mid-1;
                    ans=dst;
                    ansX=x;
                }else if(x==-1){
                    r=mid-1;
                }
                else{
                    l=mid+1;
                }
            }
            if(ans==-1){
                dp[j][i]=(X*(a[m-1]-a[i]))+t;
            }else{
                dp[j][i]=dp[ansX][ans];
            }
        }
    }
}


// Answers one query: the time at which the reserve bus, leaving station 0 at
// time Y, reaches the last station (m-1).
//
// The reserve runs freely, reaching station k at Y + X*a[k]. It does so until
// the first station where the latest bus among those that left strictly
// before Y arrives no earlier than it; the answer is then the precomputed dp
// for that bus and station. If no bus left before Y, or the reserve is never
// caught, it arrives at Y + X*a[m-1].
long long arrival_time(long long Y) 
{
    const int freeArrival=X*a[m-1]+Y;

    // Bus (smallest index in its tie group) that is last strictly ahead of Y.
    const int front=petit_lower(Y,0);
    if(front==-1) return freeArrival;

    // Follows `hops` steps of the dep chain from `front`, starting at station
    // 0. Returns the resulting bus index, or a negative value if the chain is
    // interrupted. When non-negative, the bus is located at station `hops`.
    auto follow=[&](int hops){
        int bus=front;
        int station=0;
        for(int bit=LOG-1;bit>=0;--bit){
            if(!(hops&(1<<bit))) continue;
            bus=dep[bus][station][bit];
            station+=(1<<bit);
            if(bus<0) break;
        }
        return bus;
    };

    // Smallest number of stations after which the reserve has been caught.
    int lo=1, hi=m-1;
    int caughtAt=-1, caughtBus=-1;
    while(lo<=hi){
        const int hops=(lo+hi)/2;
        const int bus=follow(hops);
        if(bus==-1){
            hi=hops-1;
        }else if(bus>=0&&s[bus][hops]>=X*a[hops]+Y){
            caughtAt=hops;
            caughtBus=bus;
            hi=hops-1;
        }else{
            lo=hops+1;
        }
    }
    return caughtAt==-1 ? freeArrival : dp[caughtBus][caughtAt];
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...