이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include"holiday.h"
#include <bits/stdc++.h>
using namespace std;
// Competitive-programming shorthands used throughout this file.
// sz(x): container size as a signed int (avoids signed/unsigned comparisons).
#define sz(x) (int)(x).size()
// all(arr): full iterator range, for passing a container to STL algorithms.
#define all(arr) (arr).begin(), (arr).end()
#define ll long long
#define ld long double
#define pb push_back
#define fi first
#define se second
// Size of the root[] array of persistent-tree versions; solve() writes root[0..n-1],
// so N must be at least the length of the longest side array passed to solve().
const int N = 1e5 + 20;
// Node of a persistent segment tree over compressed value indices.
// `sum` is the total of all values inserted into this node's range and `cnt` is how many
// values were inserted; `l`/`r` are the children (both null at a leaf).
struct Node{
    ll sum = 0;
    int cnt = 0;
    Node* l = nullptr;
    Node* r = nullptr;

    // Empty node: zero sum, zero count, no children.
    Node() = default;

    // Leaf carrying an explicit sum and count.
    Node(ll sum_, int cnt_) : sum(sum_), cnt(cnt_) {}

    // Internal node aggregating two existing children (both must be non-null).
    Node(Node* l_, Node* r_)
        : sum(l_->sum + r_->sum), cnt(l_->cnt + r_->cnt), l(l_), r(r_) {}
};
// One persistent-tree version per prefix: solve() stores the tree holding a[0..p] in root[p].
Node* root[N];

// Allocates an empty tree (every sum and count zero) covering index range [tl, tr].
Node* build(int tl, int tr){
    if (tl != tr){
        const int mid = (tl + tr) / 2;
        Node* left = build(tl, mid);
        Node* right = build(mid + 1, tr);
        return new Node(left, right);
    }
    return new Node();
}
// Returns the root of a new version equal to `v` plus one copy of `val` stored at leaf `pos`.
// Only the nodes on the root-to-leaf path are freshly allocated; every other subtree is
// shared with `v`, which itself is left unchanged.
Node* modify(Node* v, int tl, int tr, int pos, int val){
    if (tl == tr) return new Node(v -> sum + val, v -> cnt + 1);
    const int mid = (tl + tr) / 2;
    Node* left = v -> l;
    Node* right = v -> r;
    if (pos <= mid) left = modify(left, tl, mid, pos, val);
    else right = modify(right, mid + 1, tr, pos, val);
    return new Node(left, right);
}
// Sums the `need` largest values stored in the subtree `v` (range [tl, tr]), greedily taking
// from higher indices first. `need` is decremented by the number of values taken, so the
// caller can see how much of the budget was left unused.
ll query(Node* v, int tl, int tr, int& need){
    // The whole subtree fits into the remaining budget: take all of it.
    if (need >= v -> cnt){
        need -= v -> cnt;
        return v -> sum;
    }
    if (tl != tr){
        const int mid = (tl + tr) / 2;
        // Higher indices (the right half) are consumed first.
        ll total = query(v -> r, mid + 1, tr, need);
        if (need != 0) total += query(v -> l, tl, mid, need);
        return total;
    }
    // Partial leaf: assumes every value at a leaf is identical (true for the compressed
    // indices solve() uses), so sum / cnt is that value; take `need` copies of it.
    const ll per_copy = v -> sum / v -> cnt;
    const ll taken = per_copy * need;
    need = 0;
    return taken;
}
// State shared by calc()/rec(), set by solve() before each run:
// m      - number of distinct values, i.e. the tree covers indices [0, m - 1];
// weight - days spent per city of distance (callers pass 1 one-way, 2 for a round trip).
int m;
int weight;
// dp[t] - best total for t days on the current side; filled by rec().
vector<ll> dp;

// Value of spending `d` days when the farthest city reached is index `pref`: the
// d - pref * weight largest values among the first pref + 1 cities, or 0 if no day is left.
ll calc(int d, int pref){
    int budget = d - pref * weight;
    if (budget <= 0) return 0;
    return query(root[pref], 0, m - 1, budget);
}
void rec(int l, int r, int optl, int optr){
if (l > r) return;
int mid = (l + r) / 2;
int opt = optl;
ll optval = 0;
for (int i = optl; i <= optr; i++){
ll cur = calc(mid, i);
if (cur > optval) optval = cur, opt = i;
}
dp[mid] = optval;
rec(l, mid - 1, optl, opt);
rec(mid + 1, r, opt, optr);
}
// Computes, for one side of the start city, dp[t] for every t in [0, maxd]:
//   dp[t] = max over p of (sum of the t - p * w largest values among a[0..p]),
// where a is ordered nearest-first and w is the day cost per city of distance.
// Returns maxd + 1 zeros for an empty side.
vector<ll> solve(const vector<int>& a, int maxd, int w){
    const int n = sz(a);
    if (!n){
        return vector<ll>(maxd + 1);
    }
    // Coordinate-compress the values so the tree is indexed by value rank.
    vector<int> vals = a;
    sort(all(vals));
    vals.erase(unique(all(vals)), vals.end());
    m = sz(vals), weight = w;
    // root[p] holds the multiset a[0..p]. The versions depend only on `a`, not on `w`, so
    // they are rebuilt whenever `a` differs from the array they were last built for.
    // A w == 2 call therefore never reads trees left over from an unrelated array.
    // NOTE(review): tree nodes are never freed; each rebuild leaks the previous versions.
    static vector<int> built_for;
    if (built_for != a){
        root[0] = build(0, m - 1);
        for (int i = 0; i < n; i++){
            if (i) root[i] = root[i - 1];
            const int pos = lower_bound(all(vals), a[i]) - vals.begin();
            root[i] = modify(root[i], 0, m - 1, pos, a[i]);
        }
        built_for = a;
    }
    // Reset instead of resize so no value from a previous call can survive.
    dp.assign(maxd + 1, 0);
    rec(0, maxd, 0, n - 1);
    return dp;
}
long long int findMaxAttraction(int n, int start, int d, int a[]) {
if (!d) return 0;
vector<int> lft, rgt;
for (int i = 0; i < start; i++) lft.pb(a[i]);
for (int i = start + 1; i < n; i++) rgt.pb(a[i]);
reverse(all(lft));
auto left_noret = solve(lft, d, 1), left_ret = solve(lft, d, 2);
auto right_noret = solve(rgt, d, 1), right_ret = solve(rgt, d, 2);
ll ans = 0;
for (int takeleft = 1; takeleft <= d; takeleft++){
ans = max(ans, left_noret[takeleft - 1]);
}
for (int takeright = 1; takeright <= d; takeright++){
ans = max(ans, right_noret[takeright - 1]);
}
for (int dleft = 1; dleft < d; dleft++){
int dright = d - dleft;
if (dleft >= 2) ans = max(ans, left_ret[dleft - 2] + right_noret[dright - 1]);
if (dright >= 2) ans = max(ans, left_noret[dleft - 1] + right_ret[dright - 2]);
}
ll ret = ans;
d--;
ans = 0;
for (int takeleft = 1; takeleft <= d; takeleft++){
ans = max(ans, left_noret[takeleft - 1]);
}
for (int takeright = 1; takeright <= d; takeright++){
ans = max(ans, right_noret[takeright - 1]);
}
for (int dleft = 1; dleft < d; dleft++){
int dright = d - dleft;
if (dleft >= 2) ans = max(ans, left_ret[dleft - 2] + right_noret[dright - 1]);
if (dright >= 2) ans = max(ans, left_noret[dleft - 1] + right_ret[dright - 2]);
}
return max(ret, ans + a[start]);
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |