제출 #531671

#제출 시각아이디문제언어결과실행 시간메모리
531671erke수열 (APIO14_sequence)C++11
71 / 100
2071 ms103584 KiB
#include <bits/stdc++.h> using namespace std; typedef long long ll; typedef vector<ll> vl; const ll oo = 3e18; int n, k; vector<ll> comp; template <typename T> bool chmax(T &a, T b) { return (a < b) ? a = b, 1 : 0; } struct Line { ll a, b, id; Line() { a = 0; b = -oo; id = -1; } Line(ll _a, ll _b, ll _id): a(_a), b(_b), id(_id) {} ll operator()(ll x) { return a * x + b; } }; struct Node { Line line; Node *left, *right; Node() { left = right = nullptr; } }; struct LCT { Node *root; LCT() { root = new Node; } void update(Node* &node, ll l, ll r, Line line) { if (l > r) return; if (node == nullptr) node = new Node; if (l == r) { if (line(comp[l]) > node->line(comp[l])) node->line = line; return; } ll m = (l + r) / 2; if (line.a < node->line.a) swap(line, node->line); if (line(comp[m]) > node->line(comp[m])) { swap(line, node->line); update(node->left, l, m, line); } else update(node->right, m + 1, r, line); } pair<ll,ll> get(Node* &node, ll l, ll r, ll x, ll i) { if (l > r || node == nullptr) return {-oo, -1}; pair<ll,ll> cur = {node->line(x), node->line.id}; if (node->line.id >= i) cur = {-oo, -1}; if (l == r) return cur; ll m = (l + r) / 2; if (x < comp[m]) return max(cur, get(node->left, l, m, x, i)); else return max(cur, get(node->right, m + 1, r, x, i)); } void update(Line line) { update(root, 0, (int) comp.size() - 1, line); } pair<ll,ll> get(ll x, ll i) { return get(root, 0, (int) comp.size() - 1, x, i); } void clear(Node* &node) { if (node == nullptr) return; clear(node->left); clear(node->right); delete node; } void clear() { clear(root); root = new Node; } }; int main() { cin.tie(0)->sync_with_stdio(0); cin >> n >> k; vector<ll> a(n + 1), s(n + 1); for (int i = 1; i <= n; i++) { cin >> a[i]; s[i] = a[i] + s[i - 1]; comp.push_back(s[i]); } sort(comp.begin(), comp.end()); comp.resize(unique(comp.begin(), comp.end()) - comp.begin()); vector<LCT> lct(2); vector<vector<int>> trace(n + 1, vector<int>(k + 1)); ll ans = 0; for (int l = 0; l <= k; l++) { lct[l % 2].clear(); for (int i = 1; i <= n; i++) { ll tmp = 0; if (l >= 1) tie(tmp, trace[i][l]) = lct[(l + 1) % 2].get(s[i], i); lct[l % 2].update(Line(s[i], - (s[i] * s[i]) + tmp, i)); if (i == n && l == k) ans = tmp; } } cout << ans << '\n'; for (int i = n, j = k; j >= 1; j--) { i = trace[i][j]; cout << i << ' '; } cout << '\n'; }
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...