// This submission was migrated from the previous version of oj.uz, which used a different machine for grading; it may produce a different result if resubmitted.
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
// Order-statistics set from pb_ds: behaves like std::set and additionally
// supports find_by_order(k) and order_of_key(x).
template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
// Short type aliases.
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
// Contest shorthand macros.
// fastio: unsync iostreams from C stdio and untie cin from cout.
#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
// Turns endl into a plain newline character, so lines are not flushed.
#define endl '\n'
// Container size as int (avoids signed/unsigned comparisons).
#define sz(a) (int)a.size()
// Number of set bits in a 64-bit value.
#define setbits(x) __builtin_popcountll(x)
// pair member shorthands.
#define ff first
#define ss second
#define conts continue
// Ceiling of x / y for non-negative x and positive y.
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl
// Loop helpers: rep -> i in [0, n), rep1 -> i in [1, n],
// rev -> i from s down to e (inclusive), trav -> range-for by reference.
#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)
// Chmin helper: replaces a with b only when b compares strictly smaller.
template<typename T>
void amin(T &a, T b) {
    if (b < a) a = b;
}
// Chmax helper: replaces a with b only when b compares strictly larger.
template<typename T>
void amax(T &a, T b) {
    if (a < b) a = b;
}
#ifdef LOCAL
#include "debug.h"
#else
#define debug(x) 42
#endif
/*
refs:
editorial
key idea:
to check whether an assignment is possible, apply Hall's theorem:
it is impossible iff for some set s of team sizes,
|segments that cover at least one size in s| - (total demand of s, i.e. sum of team sizes) < 0
so do a dp over the sorted team sizes c:
dp[i] = min value of that expression over sets whose largest chosen size is c[i]
transition from every dp[j] with j < i:
dp[i] = min(dp[j] + add) - demand(c[i]), add = #segments with c[j]+1 <= l <= c[i] and r >= c[i]
counting such segments can be done with a wavelet tree
this works in O(m^2 * log(n))
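array c can be compressed so that all its elements are distinct (equal sizes are merged and their demands added)
then the number of distinct elements is O(sqrt(n)), because their sum is at most n
so the total cost is about S*sqrt(n)*log(n), where S is the total number of team sizes over all queries
can we optimize the transitions?
for a given (i,j,k) with i < j < k, suppose transitioning from i is at least as good as from j when computing dp[k]
then i stays at least as good as j for every index after k
so maintain a set of candidates that are not yet dominated
for each adjacent pair, find the first index at which the larger candidate becomes dominated and schedule its removal (this can again be done with a wavelet tree)
the best j is the largest candidate in the set
compute dp[i] using this j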
*/
// Numeric constants. In the code shown, only inf1 is referenced
// (as the initial "unreached" dp value in can()).
constexpr int MOD = 1000000007;
constexpr int N = 500005;
constexpr int inf1 = 1000000005;
constexpr ll inf2 = 1000000000000000005LL;
#include "teams.h"
// Wavelet tree over an int array whose values are expected to lie in [0, siz].
// Public queries take 0-indexed, inclusive position ranges [l, r]:
//   get_cnt(l, r, vl, vr) : number of elements of a[l..r] with vl <= value <= vr
//   kth(l, r, k)          : k-th smallest value in a[l..r] (k is 1-indexed)
//   kth_largest(l, r, k)  : k-th largest value in a[l..r]; -1 if k > r-l+1
// Nodes are allocated with new and never freed: the tree is built once and
// kept for the rest of the program.
struct wavelet_tree{
    struct node{
        node *l = nullptr, *r = nullptr;
        // b[p] = how many of this node's first p elements go to the left child
        std::vector<int> b;
    };
    node* root = nullptr;  // stays null until build() is called
    int siz = 0;           // value range handled by the tree is [0, siz]
    wavelet_tree() = default;

    // Builds the subtree for the elements a[inds[0]], a[inds[1]], ... (in that
    // order), whose values lie in [l, r]. The child index lists are moved
    // down instead of copied.
    node* build(std::vector<int> &a, std::vector<int> inds, int l, int r){
        if(l > r) return nullptr;
        const int mid = (l+r) >> 1;
        std::vector<int> b(inds.size()+1);
        std::vector<int> la, ra;
        for(std::size_t p = 0; p < inds.size(); ++p){
            const int id = inds[p];
            const bool to_left = a[id] <= mid;
            b[p+1] = b[p] + to_left;
            (to_left ? la : ra).push_back(id);
        }
        node* curr = new node();
        curr->b = std::move(b);
        if(l != r){
            curr->l = build(a, std::move(la), l, mid);
            curr->r = build(a, std::move(ra), mid+1, r);
        }
        return curr;
    }

    // Builds the tree over all of a, with value range [0, mx_siz].
    void build(std::vector<int> &a, int mx_siz){
        siz = mx_siz;
        std::vector<int> inds(a.size());
        std::iota(inds.begin(), inds.end(), 0);
        root = build(a, std::move(inds), 0, siz);
    }

    // Counts elements at node-local positions [l, r] (1-indexed) of node u,
    // whose value range is [lx, rx], having value in [vl, vr].
    int get_cnt(node* u, int lx, int rx, int l, int r, int vl, int vr){
        if(!u || l > r) return 0;
        if(lx > vr || rx < vl) return 0;
        if(lx >= vl && rx <= vr) return r-l+1;
        const int mid = (lx+rx) >> 1;
        const std::vector<int> &b = u->b;
        // Positions that went left / right, renumbered inside each child.
        return get_cnt(u->l, lx, mid, b[l-1]+1, b[r], vl, vr)
             + get_cnt(u->r, mid+1, rx, l-b[l-1], r-b[r], vl, vr);
    }
    int get_cnt(int l, int r, int vl, int vr){
        return get_cnt(root, 0, siz, l+1, r+1, vl, vr);
    }

    // k-th smallest (1-indexed) value among node-local positions [l, r].
    // If k exceeds the number of elements, the walk always goes right and the
    // top value rx is returned; kth_largest() relies on this for k <= 0.
    int kth(node* u, int lx, int rx, int l, int r, int k){
        if(lx == rx) return lx;
        const int mid = (lx+rx) >> 1;
        const std::vector<int> &b = u->b;
        const int cnt = b[r]-b[l-1];  // elements of [l, r] stored in the left child
        if(k <= cnt) return kth(u->l, lx, mid, b[l-1]+1, b[r], k);
        return kth(u->r, mid+1, rx, l-b[l-1], r-b[r], k-cnt);
    }
    int kth(int l, int r, int k){
        return kth(root, 0, siz, l+1, r+1, k);
    }

    // k-th largest value in a[l..r]; -1 if k > r-l+1. For k <= 0 this
    // returns siz (see kth()).
    int kth_largest(int l, int r, int k){
        const int len = r-l+1;
        if(k > len) return -1;
        return kth(l, r, len-k+1);
    }
};
// Number of (A[i], B[i]) pairs passed to init().
int n;
// The (A, B) pairs; sorted in init() by A, ties broken by B.
vector<pii> a;
// Wavelet tree over the B values of `a`, in the same sorted order.
wavelet_tree wt;
// lb[x] = first index p in `a` with a[p].ff >= x (n if none); built in init().
vector<int> lb;
// Grader entry point: stores the n_ pairs (A[i], B[i]) and builds the
// structures (sorted `a`, `lb`, `wt`) that every can() query reads.
void init(int n_, int A[], int B[]) {
    n = n_;
    // Keep the pairs sorted by A (ties broken by B).
    for (int s = 0; s < n; ++s) a.emplace_back(A[s], B[s]);
    sort(a.begin(), a.end());

    // lb[x] = smallest index p with a[p].first >= x, or n when there is none:
    // record the first occurrence of each exact A, then take suffix minima.
    lb.assign(n + 5, n);
    for (int s = n - 1; s >= 0; --s) lb[a[s].first] = s;
    for (int x = n; x >= 0; --x) lb[x] = min(lb[x], lb[x + 1]);

    // B values in A-sorted order, so the position range
    // [lb[lo], lb[hi + 1] - 1] means "pairs with lo <= A <= hi".
    vector<int> rights(n);
    for (int s = 0; s < n; ++s) rights[s] = a[s].second;
    wt.build(rights, n + 1);
}
// Query: can the pairs (A, B) stored by init() be assigned to teams of sizes
// K[0..m-1], each pair joining one team whose size s satisfies A <= s <= B?
// Returns 1 if yes, 0 if no. Note: K is sorted in place.
//
// c[1..] holds the distinct sizes with their total demand (size * multiplicity),
// c[0] = (0, 0) is a sentinel. dp[i] is meant to be the minimum, over
// increasing chains of sizes ending at c[i], of
//   (#pairs covering at least one chosen size) - (total demand of the chain);
// the answer is "yes" iff no dp value is negative (Hall's condition).
// The O(m^2) transition
//   dp[i] = min_j dp[j] + get(c[j].ff+1, c[i].ff, c[i].ff) - c[i].ss
// is sped up by keeping only non-dominated predecessors in `st` and always
// taking the largest one.
int can(int m, int K[]) {
    // No teams requested: nothing to assign (also avoids reading K[0] below).
    if(m <= 0) return 1;
    // A total demand above the number of pairs can never be met.
    // long long keeps the running total from overflowing.
    long long sum = 0;
    rep(i,m){
        sum += K[i];
        if(sum > n){
            return 0;
        }
    }
    sort(K,K+m);
    // c[i] = (team size, total demand of that size)
    vector<pii> c;
    c.pb({K[0],K[0]});
    rep1(i,m-1){
        if(K[i] == K[i-1]){
            c.back().ss += K[i];
        }
        else{
            c.pb({K[i],K[i]});
        }
    }
    c.insert(c.begin(),{0,0}); // sentinel: the empty chain
    m = sz(c);                 // from here on m counts the entries of c
    vector<int> dp(m,inf1);
    dp[0] = 0;
    // Number of pairs with mnl <= A <= mxl and B >= mnr.
    auto get = [&](int mnl, int mxl, int mnr){
        return wt.get_cnt(lb[mnl],lb[mxl+1]-1,mnr,n);
    };
    // For predecessors i < j: index into c from which taking i is no worse
    // than taking j. Taking i instead of j additionally counts the pairs with
    // A in (c[i].ff, c[j].ff] and B >= the target size, so j is dominated once
    // that count is at most dp[j] - dp[i]. May return m (not within this
    // query), or -1 when kth_largest finds fewer than dp[j] - dp[i] such pairs.
    auto first_bad = [&](int i, int j){
        int diff = dp[j]-dp[i];
        int l = lb[c[i].ff+1];
        int r = lb[c[j].ff+1]-1;
        int val = wt.kth_largest(l,r,diff);
        if(val == -1) return -1;
        // If more than diff pairs have B >= val (ties at val), move val up by one.
        if(diff < wt.get_cnt(l,r,val,n)) val++;
        return static_cast<int>(upper_bound(all(c),make_pair(val,-1))-c.begin());
    };
    // st: candidate predecessors still considered useful (0 is never removed).
    // leave[k]: candidates scheduled to be erased before dp[k] is computed.
    set<int> st;
    st.insert(0);
    vector<vector<int>> leave(m+5);
    rep1(i,m-1){
        // Drop candidates dominated from index i on; after each removal the two
        // new neighbours form an adjacent pair that needs its own schedule.
        while(!leave[i].empty()){
            int j = leave[i].back();
            leave[i].pop_back();
            if(!st.count(j)) conts;
            st.erase(j);
            auto it = st.upper_bound(j);
            if(it != st.end() and it != st.begin()){
                int x = *prev(it), y = *it;
                int pos = first_bad(x,y);
                if(pos != -1){
                    amax(pos,i);
                    assert(pos <= m);
                    leave[pos].pb(y);
                }
            }
        }
        // The largest surviving candidate is the best predecessor.
        {
            int j = *st.rbegin();
            dp[i] = dp[j]+get(c[j].ff+1,c[i].ff,c[i].ff)-c[i].ss;
        }
        // i is the new largest candidate; schedule when its predecessor beats it.
        auto it = st.insert(i).first;
        {
            int x = *prev(it), y = *it;
            int pos = first_bad(x,y);
            if(pos != -1){
                amax(pos,i+1);
                leave[pos].pb(y);
            }
        }
    }
    int mn = *min_element(all(dp));
    return mn >= 0;
}
/*
Grader page residue (kept for reference):
Compilation message (stderr)
teams.cpp: In lambda function:
teams.cpp:246:56: warning: conversion from '__gnu_cxx::__normal_iterator<std::pair<int, int>*, std::vector<std::pair<int, int> > >::difference_type' {aka 'long int'} to 'int' may change value [-Wconversion]
246 | int pos = upper_bound(all(c),make_pair(val,-1))-c.begin();
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/