Submission #1009454

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|---|---|---|---|---|---|---|
| 1009454 | | PotatoMan | Overtaking (IOI23_overtaking) | C++17 | 100 / 100 | 1402 ms | 67432 KiB |
#include "overtaking.h"
#include <bits/stdc++.h>
#define ll long long
#define pii pair<ll,ll>
#define mp make_pair
using namespace std;

// A bus tracked by init(): W is its travel time per unit distance and S its
// current time at a station (initially the departure time T[i]).
struct Bus{
	ll W, S;
	Bus(ll w, ll s) : W(w), S(s) {}
};

// Strict weak ordering for buses: earlier current time S first; among equal
// times, smaller W first.
bool cmp(Bus b1,Bus b2){
	return std::tie(b1.S, b1.W) < std::tie(b2.S, b2.W);
}

// Capacities of the fixed-size tables. kMaxStations bounds the station index
// (M) and kMaxBuses bounds the bus index (N). Presumably sized for the task's
// M, N <= 1000 -- confirm against the constraints before changing.
constexpr int kMaxStations = 1005;
constexpr int kMaxBuses = 1005;

vector<Bus> bus;                      // buses with W > X; init() overwrites S with their arrival time per station
vector<vector<pii>> AT(kMaxStations); // AT[i][j] = {arrival at station i, running max of arrivals at station i+1}, sorted by time
vector<ll> station;                   // station positions S[0..m-1]
ll n, m, x;                           // kept-bus count, station count, reserve bus time per unit distance
ll memo[kMaxStations][kMaxBuses];     // memo[i][j] caches DP(i, j); -1 means not computed yet

int FindP(int i,ll time){
	int cur = lower_bound(AT[i].begin(),AT[i].end(),mp(time,(ll)-1)) - AT[i].begin();
	int p = i, tl = i, tr = m-1;
	while( tl <= tr ){
		int tm = (tl + tr) / 2;
		ll exp_time = time + (station[tm]-station[i])*x;
		int cars_ahead = lower_bound(AT[tm].begin(),AT[tm].end(),mp(exp_time,(ll)-1)) - AT[tm].begin(); 
		if( cars_ahead == cur ){
			p = tm;
			tl = tm+1;
		}
		else{
			tr = tm-1;
		}
	}
	return p;
}

long long DP(int i,int j){
	if( memo[i][j] != -1 ) return memo[i][j];
	ll time = AT[i][j].first;
	int p = FindP(i,time);
	if( p == m-1 ) return memo[i][j] = time + (station[m-1]-station[i])*x;
	else{
		ll ET = time + (station[p]-station[i])*x;
		auto temp = lower_bound(AT[p].begin(),AT[p].end(),mp(ET,(ll)0));
		temp--; //long jump
		int next = lower_bound(AT[p+1].begin(),AT[p+1].end(),mp((*temp).second,(ll)0)) - AT[p+1].begin(); //normal jump
		return memo[i][j] = DP(p+1,next);
	}
}

void init(int L, int N, std::vector<long long> T, std::vector<int> W, int X, int M, std::vector<int> S)
{
	m = M, x = X;
	//Get rid of unnecessary buses
	for(int i = 0 ; i < N ; i++){
		if( W[i] > X ) bus.push_back(Bus(W[i],T[i]));
	}
	for(int i = 0 ; i < M ; i++){
		station.push_back(S[i]);
	}
	n = bus.size();
	for(int i = 0 ; i < m ; i++){
		ll mx = 0;
		sort(bus.begin(),bus.end(),cmp);
		for(int j = 0 ; j < n ; j++){
			ll temp = bus[j].S;
			if( i != m-1 ) mx = max(mx,bus[j].S+(station[i+1]-station[i])*bus[j].W);
			AT[i].push_back({temp,mx});
			bus[j].S = mx;
		}
	}
	memset(memo,-1,sizeof(memo));
}

long long arrival_time(long long Y)
{	
	int p = FindP(0,Y);
	if( p == m-1 ) return Y + station[m-1]*x;
	else{
		ll ET = Y + station[p]*x;
		auto temp = lower_bound(AT[p].begin(),AT[p].end(),mp(ET,(ll)0));
		temp--; //long jump
		int next = lower_bound(AT[p+1].begin(),AT[p+1].end(),mp((*temp).second,(ll)0)) - AT[p+1].begin(); //normal jump
		return DP(p+1,next);
	}
}
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...