Submission #929097

# | Time | Username | Problem | Language | Result | Execution time | Memory
929097 | GrindMachine | Teams (IOI15_teams) | C++17
100 / 100
1148 ms | 251956 KiB
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

using namespace std;
using namespace __gnu_pbds;

template<typename T> using Tree = tree<T, null_type, less<T>, rb_tree_tag, tree_order_statistics_node_update>;
typedef long long int ll;
typedef long double ld;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;

#define fastio ios_base::sync_with_stdio(false); cin.tie(NULL)
#define pb push_back
#define endl '\n'
#define sz(a) (int)a.size()
#define setbits(x) __builtin_popcountll(x)
#define ff first
#define ss second
#define conts continue
#define ceil2(x,y) ((x+y-1)/(y))
#define all(a) a.begin(), a.end()
#define rall(a) a.rbegin(), a.rend()
#define yes cout << "Yes" << endl
#define no cout << "No" << endl
#define rep(i,n) for(int i = 0; i < n; ++i)
#define rep1(i,n) for(int i = 1; i <= n; ++i)
#define rev(i,s,e) for(int i = s; i >= e; --i)
#define trav(i,a) for(auto &i : a)

template<typename T> void amin(T &a, T b) { a = min(a,b); }
template<typename T> void amax(T &a, T b) { a = max(a,b); }

#ifdef LOCAL
#include "debug.h"
#else
#define debug(x) 42
#endif

/*

refs: edi

*/

const int MOD = 1e9 + 7;
const int N = 5e5 + 5;
const int inf1 = int(1e9) + 5;
const ll inf2 = ll(1e18) + 5;

// The grader provides teams.h (declarations of init/can). Guarding the
// include lets this file also be compiled stand-alone for local testing.
#if __has_include("teams.h")
#include "teams.h"
#endif

// Wavelet tree over an int array whose values lie in [0, siz].
// Queries on a contiguous 0-indexed position range [l, r]:
//   get_cnt(l, r, vl, vr) -> number of elements with value in [vl, vr]
//   kth(l, r, k)          -> k-th smallest value (k is 1-indexed)
//   kth_largest(l, r, k)  -> k-th largest value, or -1 if k > r-l+1
// Each node covers a value range [lx, rx] and stores b, where b[i] is how
// many of the node's first i elements (in original order) have value <= mid
// and therefore went to the left child.
// Nodes are allocated with new and never freed; the tree is built once.
struct wavelet_tree{
    struct node{
        node *l = nullptr, *r = nullptr;
        vector<int> b;
    };

    node* root = nullptr;
    int siz = 0;

    // Builds the subtree for value range [l, r] from the positions inds of a
    // (inds is kept in original order). Returns nullptr for an empty range.
    node* build(const vector<int> &a, const vector<int> &inds, int l, int r){
        if(l > r) return nullptr;
        int mid = (l+r) >> 1;
        vector<int> b(sz(inds)+1);
        vector<int> la, ra;
        rep(i,sz(inds)){
            int x = a[inds[i]];
            b[i+1] = b[i]+(x <= mid);
            if(x <= mid) la.pb(inds[i]);
            else ra.pb(inds[i]);
        }
        node* curr = new node();
        curr->b = std::move(b);   // b is not used after this point
        if(l != r){
            curr->l = build(a,la,l,mid);
            curr->r = build(a,ra,mid+1,r);
        }
        return curr;
    }

    // Builds the tree over all of a; values must lie in [0, mx_siz].
    void build(const vector<int> &a, int mx_siz){
        siz = mx_siz;
        vector<int> inds(sz(a));
        rep(i,sz(a)) inds[i] = i;
        root = build(a,inds,0,siz);
    }

    // Counts values in [vl, vr] among positions l..r (1-indexed within u,
    // whose value range is [lx, rx]).
    int get_cnt(node* u, int lx, int rx, int l, int r, int vl, int vr){
        if(!u) return 0;
        if(l > r) return 0;
        if(lx > vr or rx < vl) return 0;
        if(lx >= vl and rx <= vr) return r-l+1;
        int mid = (lx+rx) >> 1;
        auto &b = u->b;
        // Positions l..r map to b[l-1]+1..b[r] in the left child and to
        // l-b[l-1]..r-b[r] in the right child.
        int res = 0;
        res += get_cnt(u->l,lx,mid,b[l-1]+1,b[r],vl,vr);
        res += get_cnt(u->r,mid+1,rx,l-b[l-1],r-b[r],vl,vr);
        return res;
    }

    int get_cnt(int l, int r, int vl, int vr){
        return get_cnt(root,0,siz,l+1,r+1,vl,vr);
    }

    // k-th smallest value among positions l..r (1-indexed within u).
    // If k exceeds the number of elements in the range, every step goes
    // right and the result is rx (the largest value of the range).
    int kth(node* u, int lx, int rx, int l, int r, int k){
        if(lx == rx) return lx;
        int mid = (lx+rx) >> 1;
        auto &b = u->b;
        int cnt = b[r]-b[l-1];
        if(k <= cnt) return kth(u->l,lx,mid,b[l-1]+1,b[r],k);
        return kth(u->r,mid+1,rx,l-b[l-1],r-b[r],k-cnt);
    }

    int kth(int l, int r, int k){
        return kth(root,0,siz,l+1,r+1,k);
    }

    // k-th largest value in [l, r]; -1 if k > r-l+1. For k <= 0 the index
    // passed to kth exceeds the range size, so the result is siz.
    int kth_largest(int l, int r, int k){
        int len = r-l+1;
        if(k > len) return -1;
        return kth(l,r,len-k+1);
    }
};

int n;
vector<pii> a;      // students as (A, B), sorted ascending
wavelet_tree wt;    // over the B values of a, in sorted order
vector<int> lb;     // lb[x] = first index i with a[i].ff >= x (n if none)

// Stores the n students' (A[i], B[i]) pairs and builds the query structures.
// Assumes 0 <= A[i] < n+5 and 0 <= B[i] <= n+1 (array / value bounds used
// below) — presumably 1 <= A[i] <= B[i] <= n per the task; confirm.
void init(int n_, int A[], int B[]) {
    n = n_;
    a.clear();   // a repeated init() starts from a clean student list
    rep(i,n) a.pb({A[i],B[i]});
    sort(all(a));
    lb = vector<int>(n+5,n);
    rep(i,n) amin(lb[a[i].ff],i);
    rev(i,n,0) amin(lb[i],lb[i+1]);   // suffix min: first index with A >= i
    vector<int> b;
    rep(i,n) b.pb(a[i].ss);
    wt.build(b,n+1);
}

// Answers one query for the m project sizes K[0..m-1]; returns 1 or 0.
// K is not modified. Equal sizes are grouped into c[1..] as
// (size, total demand), with c[0] = (0, 0) as a sentinel, and
//   dp[i] = min_{j<i} dp[j] + #{students: c[j] < A <= c[i], B >= c[i]} - demand[i]
// is computed. The answer is 1 iff every dp value is >= 0.
// The min over j is maintained with a set of surviving candidates: for
// x < y, once x becomes at least as good as y it stays so for all larger
// sizes, so y is scheduled for removal at that point and the largest
// surviving candidate is always optimal.
int can(int m, int K[]) {
    if(m <= 0) return 1;   // nothing to staff; also avoids reading K[0]

    int sum = 0;
    rep(i,m){
        sum += K[i];
        if(sum > n) return 0;   // more seats requested than students exist
    }

    vector<int> ks(K, K+m);   // sort a copy: leave the caller's array intact
    sort(all(ks));

    // c[i] = (size, total number of students demanded at that size)
    vector<pii> c;
    c.pb({0,0});
    rep(i,m){
        if(i > 0 and ks[i] == ks[i-1]) c.back().ss += ks[i];
        else c.pb({ks[i],ks[i]});
    }
    m = sz(c);

    vector<int> dp(m,inf1);
    dp[0] = 0;

    // Number of students with A in [mnl, mxl] and B in [mnr, n].
    auto get = [&](int mnl, int mxl, int mnr){
        return wt.get_cnt(lb[mnl],lb[mxl+1]-1,mnr,n);
    };

    // For candidates i < j, j remains at least as good as i at size s while
    //   dp[j] - dp[i] <= #{students: c[i] < A <= c[j], B >= s},
    // and the right side only shrinks as s grows. Returns the first index
    // pos of c from which i is at least as good as j (m if none), or -1 if
    // dp[j] - dp[i] exceeds the number of students in that A-range.
    auto first_bad = [&](int i, int j){
        int diff = dp[j]-dp[i];
        int l = lb[c[i].ff+1];
        int r = lb[c[j].ff+1]-1;
        int val = wt.kth_largest(l,r,diff);
        if(val == -1) return -1;
        // More than diff students have B >= val, so the threshold lies above.
        if(diff < wt.get_cnt(l,r,val,n)) val++;
        // c[k].ss >= 0, so this finds the first c[k].ff >= val.
        return static_cast<int>(upper_bound(all(c),make_pair(val,-1))-c.begin());
    };

    set<int> st;                      // surviving candidate predecessors
    st.insert(0);
    vector<vector<int>> leave(m+5);   // leave[p]: candidates to drop at step p

    rep1(i,m-1){
        // Drop candidates dominated from size c[i] on; each removal makes a
        // new adjacent pair, whose drop point is computed in turn.
        while(!leave[i].empty()){
            int j = leave[i].back();
            leave[i].pop_back();
            if(!st.count(j)) conts;   // already removed through another pair
            st.erase(j);
            auto it = st.upper_bound(j);
            if(it != st.end() and it != st.begin()){
                int x = *prev(it), y = *it;
                int pos = first_bad(x,y);
                if(pos != -1){
                    amax(pos,i);
                    assert(pos <= m);
                    leave[pos].pb(y);
                }
            }
        }

        // The largest surviving candidate is the optimal predecessor.
        {
            int j = *st.rbegin();
            dp[i] = dp[j]+get(c[j].ff+1,c[i].ff,c[i].ff)-c[i].ss;
        }

        st.insert(i);
        auto it = st.find(i);
        {
            int x = *prev(it), y = *it;   // 0 is never removed, so prev exists
            int pos = first_bad(x,y);
            if(pos != -1){
                amax(pos,i+1);
                leave[pos].pb(y);
            }
        }
    }

    int mn = *min_element(all(dp));
    return mn >= 0;
}

Compilation message (stderr)

teams.cpp: In lambda function:
teams.cpp:217:56: warning: conversion from '__gnu_cxx::__normal_iterator<std::pair<int, int>*, std::vector<std::pair<int, int> > >::difference_type' {aka 'long int'} to 'int' may change value [-Wconversion]
  217 |         int pos = upper_bound(all(c),make_pair(val,-1))-c.begin();
      |                   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...