#include "overtaking.h"
#include <bits/stdc++.h>
using namespace std;
// 64-bit integer used for all times and distances (a type alias instead of a
// macro, so it obeys scope rules and cannot rewrite unrelated tokens).
using ll = long long;
// Capacity of the per-station table t; must be at least the number of stations M.
const int nx = 1005;
// Per-bus time per unit of distance, copied from W by init().
vector<ll> w;
// Sorting-station positions, copied from S by init().
vector<ll> s;
// Number of sorting stations (M).
ll m;
// Reserve bus time per unit of distance (X).
ll speed;
// State of one tracked bus at a sorting station.
struct info
{
    ll reachtime;  // time the bus reaches this station
    ll idx;        // bus index into w
    ll nxt = 0;    // running-max arrival at the following station (set by init)
    ll dp = 0;     // arrival at the last station of a reserve bus tied with this slot (set by init)

    info(ll reachtime, ll idx): reachtime(reachtime), idx(idx) {}

    // Undelayed arrival time at a station `gap` distance units further on.
    ll nextstationtime(ll gap) const { return reachtime + gap * w[idx]; }

    // Orders by arrival time; on a tie the faster bus (smaller w) comes first,
    // so its expected arrival never inflates its tie partner's running maximum.
    bool operator< (const info &o) const
    {
        if (reachtime != o.reachtime) return reachtime < o.reachtime;
        return w[idx] < w[o.idx];
    }
};
vector<info> t[nx];
void init(int L, int N, std::vector<long long> T, std::vector<int> W, int X, int M, std::vector<int> S)
{
m=M;
speed=X;
for (auto x:W) w.push_back(x);
for (auto x:S) s.push_back(x);
for (int i=0; i<N; i++) if (w[i]>X) t[0].push_back(info(T[i], i));
for (int i=0; i<m; i++)
{
sort(t[i].begin(), t[i].end());
ll mx=0;
for (auto &x:t[i])
{
ll gap=0;
if (i==m-1) gap=0;
else gap=S[i+1]-S[i];
mx=max(mx, x.nextstationtime(gap));
x.nxt=mx;
t[i+1].push_back({mx, x.idx});
}
}
// for (int i=0; i<m-1; i++) for (int j=0; j<t[i].size(); j++) cout<<"debug "<<i<<' '<<j<<' '<<t[i][j].reachtime<<' '<<t[i][j].idx<<'\n';
for (int i=0; i<t[m-1].size(); i++) t[m-1][i].dp=t[m-1][i].reachtime;
for (int i=m-2; i>=0; i--)
{
for (int j=0; j<t[i].size(); j++)
{
// calculate t[i][j].dp
// find last element that < t[i][j].reach time -> if doesn't exist dp is trivial
int l=-1, r=j;
while (l<r)
{
int md=(l+r+1)/2;
if (t[i][md].reachtime<t[i][j].reachtime) l=md;
else r=md-1;
}
if (l==-1)
{
t[i][j].dp=t[i][j].reachtime+(s[m-1]-s[i])*speed;
continue;
}
// cout<<"here "<<i<<' '<<j<<'\n';
// binary search find first element where index l is overtaken
int L=i, R=m;
while (L<R)
{
int md=(L+R)/2;
if (t[md][l].reachtime>=(s[md]-s[i])*speed+t[i][j].reachtime) R=md;
else L=md+1;
}
if (L==m)
{
t[i][j].dp=t[i][j].reachtime+(s[m-1]-s[i])*speed;
continue;
}
t[i][j].dp=t[L][l].dp;
// cout<<"debug "<<i<<' '<<j<<'\n';
}
}
// cout<<"out\n";
}
long long arrival_time(long long Y)
{
if (t[0].empty()) return Y+s[m-1]*speed;
int l=-1, r=t[0].size()-1;
while (l<r)
{
int md=(l+r+1)/2;
if (t[0][md].reachtime<Y) l=md;
else r=md-1;
}
int idx=l;
if (idx==-1) return Y+s[m-1]*speed;
l=0, r=m;
while (l<r)
{
int md=(l+r)/2;
if (t[md][idx].reachtime>=(s[md]*speed)+Y) r=md;
else l=md+1;
}
if (l==m) return Y+s[m-1]*speed;
else return t[l][idx].dp;
}