Submission #832885

# | Time | Username | Problem | Language | Result | Execution time | Memory
832885 | ikaurov | Holiday (IOI14_holiday) | C++17
47 / 100
155 ms | 65536 KiB
// The grader supplies holiday.h; include it only when it is present so this
// file also compiles on its own (C++17 __has_include).
#if defined(__has_include)
#if __has_include("holiday.h")
#include "holiday.h"
#endif
#endif

#include <algorithm>
#include <cstddef>
#include <vector>

namespace {

/// Persistent segment tree over compressed value ranks [lo, hi].
///
/// Nodes live in parallel arrays linked by index: 20 bytes of storage per node
/// and no per-node heap allocation, so ~n * (log m + 1) nodes stay compact.
/// Node kEmpty is an all-zero node whose children are itself, so the empty
/// version needs no explicit build step.
class PersistentTree {
public:
    static constexpr int kEmpty = 0;

    /// Discards every version and reserves room for `capacity` nodes.
    void reset(std::size_t capacity) {
        resetStorage(sum_, capacity);
        resetStorage(cnt_, capacity);
        resetStorage(left_, capacity);
        resetStorage(right_, capacity);
        makeNode(0, 0, kEmpty, kEmpty);  // first node created -> index kEmpty
    }

    /// Returns a new version equal to `version` plus one copy of `value` at
    /// rank `pos` (lo <= pos <= hi). `version` itself stays valid and unchanged.
    int insert(int version, int lo, int hi, int pos, long long value) {
        if (lo == hi) {
            return makeNode(sum_[version] + value, cnt_[version] + 1, kEmpty, kEmpty);
        }
        const int mid = (lo + hi) / 2;
        int l = left_[version];
        int r = right_[version];
        if (pos <= mid) {
            l = insert(l, lo, mid, pos, value);
        } else {
            r = insert(r, mid + 1, hi, pos, value);
        }
        return makeNode(sum_[l] + sum_[r], cnt_[l] + cnt_[r], l, r);
    }

    /// Sum of the `need` largest values stored in `version` (all of them when
    /// it holds at most `need`). Expects need >= 0.
    long long topSum(int version, int lo, int hi, int need) const {
        long long total = 0;
        while (cnt_[version] > need) {
            if (lo == hi) {
                // A leaf holds copies of a single value, so sum / cnt is exact.
                return total + static_cast<long long>(need) * (sum_[version] / cnt_[version]);
            }
            const int mid = (lo + hi) / 2;
            const int r = right_[version];
            if (cnt_[r] >= need) {
                // The `need` largest values all sit in the right (larger) half.
                version = r;
                lo = mid + 1;
            } else {
                // Take the whole right half, then fill the rest from the left.
                total += sum_[r];
                need -= cnt_[r];
                version = left_[version];
                hi = mid;
            }
        }
        return total + sum_[version];
    }

private:
    template <class T>
    static void resetStorage(std::vector<T>& v, std::size_t capacity) {
        // Free a too-small buffer first so the old and new buffers never coexist.
        if (v.capacity() < capacity) std::vector<T>().swap(v);
        v.clear();
        v.reserve(capacity);
    }

    int makeNode(long long sum, int cnt, int l, int r) {
        sum_.push_back(sum);
        cnt_.push_back(cnt);
        left_.push_back(l);
        right_.push_back(r);
        return static_cast<int>(sum_.size()) - 1;
    }

    std::vector<long long> sum_;
    std::vector<int> cnt_;
    std::vector<int> left_;
    std::vector<int> right_;
};

/// Divide-and-conquer fill of
///   dp[x] = max over i of topSum(roots[i], x - i * weight)   (0 when <= 0 days),
/// i.e. the best sum using the first i + 1 cities of one side with
/// x - i * weight days left for visits. Relies on some optimal i being
/// non-decreasing in x, so each half only scans its part of [optLo, optHi].
struct PrefixSearch {
    const PersistentTree& tree;
    const std::vector<int>& roots;  // roots[i]: version holding a[0..i]
    int ranks;                      // number of distinct values (leaf count)
    int weight;                     // days spent per city walked
    std::vector<long long>& dp;

    long long value(int days, int last) const {
        const int visits = days - last * weight;
        return visits > 0 ? tree.topSum(roots[last], 0, ranks - 1, visits) : 0;
    }

    void fill(int lo, int hi, int optLo, int optHi) {
        if (lo > hi) return;
        const int mid = (lo + hi) / 2;
        int opt = optLo;
        long long best = 0;
        for (int i = optLo; i <= optHi; ++i) {
            const long long cur = value(mid, i);
            if (cur > best) {
                best = cur;
                opt = i;
            }
        }
        dp[mid] = best;
        fill(lo, mid - 1, optLo, opt);
        fill(mid + 1, hi, opt, optHi);
    }
};

/// Per-side tables, indexed 0..maxd.
struct SideTables {
    // oneWay[x]: best sum with x + 1 days when walking out and not returning
    // (1 day to reach a[0], then 1 day per further city).
    std::vector<long long> oneWay;
    // roundTrip[x]: best sum with x + 2 days when walking out and back to the
    // start (2 days per city walked).
    std::vector<long long> roundTrip;
};

/// Builds both tables for one side; `a[k]` is the city k + 1 steps from start.
/// Reuses `tree`'s storage, discarding whatever it held before.
SideTables computeSide(const std::vector<int>& a, int maxd, PersistentTree& tree) {
    SideTables t;
    t.oneWay.assign(maxd + 1, 0);
    t.roundTrip.assign(maxd + 1, 0);
    if (a.empty()) return t;

    std::vector<int> vals(a);
    std::sort(vals.begin(), vals.end());
    vals.erase(std::unique(vals.begin(), vals.end()), vals.end());
    const int ranks = static_cast<int>(vals.size());

    // Each insert creates one node per level: ceil(log2(ranks)) + 1 nodes.
    int depth = 0;
    while ((1 << depth) < ranks) ++depth;
    tree.reset(1 + a.size() * static_cast<std::size_t>(depth + 1));

    std::vector<int> roots(a.size());
    int version = PersistentTree::kEmpty;
    for (std::size_t i = 0; i < a.size(); ++i) {
        const int pos = static_cast<int>(
            std::lower_bound(vals.begin(), vals.end(), a[i]) - vals.begin());
        version = tree.insert(version, 0, ranks - 1, pos, a[i]);
        roots[i] = version;
    }

    const int last = static_cast<int>(a.size()) - 1;
    PrefixSearch{tree, roots, ranks, 1, t.oneWay}.fill(0, maxd, 0, last);
    PrefixSearch{tree, roots, ranks, 2, t.roundTrip}.fill(0, maxd, 0, last);
    return t;
}

/// Best total within `days` days without spending a day at the start city:
/// walk one side only, or do a round trip on one side and then walk the other.
long long bestAwayFromStart(int days, const SideTables& lft, const SideTables& rgt) {
    long long best = 0;
    for (int t = 1; t <= days; ++t) {
        best = std::max({best, lft.oneWay[t - 1], rgt.oneWay[t - 1]});
    }
    for (int dl = 1; dl < days; ++dl) {
        const int dr = days - dl;
        if (dl >= 2) best = std::max(best, lft.roundTrip[dl - 2] + rgt.oneWay[dr - 1]);
        if (dr >= 2) best = std::max(best, lft.oneWay[dl - 1] + rgt.roundTrip[dr - 2]);
    }
    return best;
}

}  // namespace

/// IOI 2014 "Holiday": maximum total of attraction values collectible in `d`
/// days starting at city `start` of `n` cities in a row, where moving to an
/// adjacent city and visiting a city each cost one day and a city's value is
/// collected at most once.
long long int findMaxAttraction(int n, int start, int d, int a[]) {
    if (d <= 0) return 0;

    // Each side is ordered by distance from the start city.
    std::vector<int> lft(a, a + start);
    std::reverse(lft.begin(), lft.end());
    const std::vector<int> rgt(a + start + 1, a + n);

    // One pool for both sides: the left trees are dead once its tables exist.
    PersistentTree tree;
    const SideTables leftSide = computeSide(lft, d, tree);
    const SideTables rightSide = computeSide(rgt, d, tree);

    const long long skipStart = bestAwayFromStart(d, leftSide, rightSide);
    // Visiting the start city costs one day wherever it is scheduled.
    const long long visitStart = bestAwayFromStart(d - 1, leftSide, rightSide) + a[start];
    return std::max(skipStart, visitStart);
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...