답안 #465758

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
465758 2021-08-16T18:05:17 Z alexxela12345 사탕 분배 (IOI21_candies) C++17
3 / 100
5000 ms 29964 KB
#include "candies.h"

#include <bits/stdc++.h>
using namespace std;

// Shorthand for 64-bit signed integers.
using ll = long long;

// Fixed-seed generator for treap priorities (node::y); the constant seed makes
// every run build identically shaped trees.
std::mt19937 rd{179};

// One treap node. The in-order sequence of nodes is the sequence of values
// whose suffix sums the rest of the file queries.
//   l, r     - children (nullptr when absent).
//   x        - ordering key (add() passes the query index here).
//   y        - random heap priority from rd; merge() keeps the larger y on top.
//   val      - stored value (add() passes the query's delta here).
//   sz       - number of nodes in this subtree.
//   sum      - sum of val over the subtree.
//   max_suf  - largest suffix sum of the subtree, counting the empty suffix (>= 0).
//   min_suf  - smallest suffix sum of the subtree, counting the empty suffix (<= 0).
// NOTE(review): sum/max_suf/min_suf are plain int; accumulating many large
// deltas can overflow (signed overflow is UB). Widening them requires changing
// the get_* accessors and their callers together.
struct node {
  node *l = nullptr;
  node *r = nullptr;

  int x;
  int y;
  int val;
  int sz = 1;
  int sum;
  int max_suf;
  int min_suf;

  // Builds a single-node subtree; draws exactly one value from rd for y.
  node(int x, int val)
      : x(x), y(rd()), val(val), sum(val), max_suf(val > 0 ? val : 0),
        min_suf(val < 0 ? val : 0) {}
};

// Null-safe aggregate accessors. An empty subtree (nullptr) has size 0 and
// sum 0, and its only suffix is the empty one, so both suffix extremes are 0.
int get_sz(node *n) { return n ? n->sz : 0; }

int get_sum(node *n) { return n ? n->sum : 0; }

int get_max_suf(node *n) { return n ? n->max_suf : 0; }
int get_min_suf(node *n) { return n ? n->min_suf : 0; }

// Returns the val at 0-based in-order position i of the treap rooted at n.
// Precondition: 0 <= i < get_sz(n); otherwise a null child is dereferenced.
int get_ct(node *n, int i) {
  for (;;) {
    const int left_size = get_sz(n->l);
    if (i < left_size) {
      n = n->l;
    } else if (i > left_size) {
      // Skip the whole left subtree and n itself.
      i -= left_size + 1;
      n = n->r;
    } else {
      return n->val;
    }
  }
}

// Returns the in-order position of key x in the treap rooted at n. When x is
// not present, returns the number of nodes that precede it by key (assuming
// the tree is ordered by x), i.e. the position where x would be inserted.
int getInd(node *n, int x) {
  int passed = 0;  // nodes already known to come before x
  while (n != nullptr) {
    if (x == n->x) {
      return passed + get_sz(n->l);
    }
    if (x < n->x) {
      n = n->l;
    } else {
      passed += get_sz(n->l) + 1;
      n = n->r;
    }
  }
  return passed;
}

// Recomputes n's aggregates from its children. A suffix of n's in-order
// sequence either lies entirely inside the right child, or spans the whole
// right child, n itself, and some suffix of the left child.
void pull(node *n) {
  node *const lc = n->l;
  node *const rc = n->r;
  const int through_n = get_sum(rc) + n->val;  // suffix sum starting at n
  n->sz = 1 + get_sz(lc) + get_sz(rc);
  n->sum = n->val + get_sum(lc) + get_sum(rc);
  n->min_suf = min(get_min_suf(rc), through_n + get_min_suf(lc));
  n->max_suf = max(get_max_suf(rc), through_n + get_max_suf(lc));
}

// Splits the treap rooted at n by position: .first receives the first k
// in-order nodes, .second the remainder. Touched nodes are re-pulled.
pair<node *, node *> split_sz(node *n, int k) {
  if (n == nullptr) {
    return {nullptr, nullptr};
  }
  const int left_size = get_sz(n->l);
  if (k > left_size) {
    // n belongs to the first part; the cut lies inside the right child.
    const pair<node *, node *> parts = split_sz(n->r, k - left_size - 1);
    n->r = parts.first;
    pull(n);
    return {n, parts.second};
  }
  // The cut lies inside (or at the end of) the left child; n goes right.
  const pair<node *, node *> parts = split_sz(n->l, k);
  n->l = parts.second;
  pull(n);
  return {parts.first, n};
}

// Concatenates treaps a and b, expecting every node of a to come before every
// node of b in order. The root with the larger priority y stays on top; on a
// tie, b's root wins.
node *merge(node *a, node *b) {
  if (a == nullptr || b == nullptr) {
    return a != nullptr ? a : b;
  }
  if (a->y <= b->y) {
    b->l = merge(a, b->l);
    pull(b);
    return b;
  }
  a->r = merge(a->r, b);
  pull(a);
  return a;
}

// Holds the root of the treap of currently active values.
struct mdata {
  node *root = nullptr;

  // Number of stored values.
  int size() { return get_sz(root); }

  // Value at in-order position i; requires 0 <= i < size().
  int get(int i) { return get_ct(root, i); }
};

// Removes the node whose key is i from mda's treap and frees it.
// If no node has key i, the treap is left unchanged.
// x is not used; it is accepted only so call sites mirror add().
void erase(mdata &mda, int i, int x) {
  static_cast<void>(x);
  const int ind = getInd(mda.root, i);
  auto pp1 = split_sz(mda.root, ind);
  auto pp2 = split_sz(pp1.second, 1);
  node *victim = pp2.first;  // the (at most one) node at position ind
  if (victim != nullptr && victim->x == i) {
    // A one-node treap has null children, so deleting it frees nothing else.
    delete victim;
  } else {
    // Key i is absent: position ind holds some other node; put it back.
    pp2.second = merge(victim, pp2.second);
  }
  mda.root = merge(pp1.first, pp2.second);
}

// Inserts a new node (key i, value x) at i's ordered position in mda's treap.
// No check is made for an existing node with the same key.
void add(mdata &mda, int i, int x) {
  const pair<node *, node *> halves = split_sz(mda.root, getInd(mda.root, i));
  node *const fresh = new node(i, x);
  mda.root = merge(merge(halves.first, fresh), halves.second);
}

int get_fst_diff(node *n, int mx_diff, int mx, int mn) {
  if (n == NULL) {
    return 0;
  }
  int mx1 = max(max(mx, get_sum(n->r) + n->val), get_max_suf(n->r));
  int mn1 = min(min(mn, get_sum(n->r) + n->val), get_min_suf(n->r));
  if (mx1 - mn1 > mx_diff) {
    return get_fst_diff(n->r, mx_diff, mx, mn) + 1 + get_sz(n->l);
  }
  return get_fst_diff(n, mx_diff, mx1, mn1);
}

int get_fst_diff(mdata &mda, int mx_diff) {
  // returns index to first arrow after which everything contains in a segment
  // of size mx
  int cur = 0;
  int mn = 0, mx = 0;
  for (int i = mda.size() - 1; i >= 0; i--) {
    int a = mda.get(i);
    cur += a;
    mn = min(mn, cur);
    mx = max(mx, cur);
    if (mx - mn >= mx_diff) {
      return i + 1;
    }
  }
  return 0;
}

// Returns the smallest suffix sum over positions fst..size()-1, counting the
// empty suffix (so the result is <= 0). The treap is split at fst and merged
// back before returning.
int get_min_suf(mdata &mda, int fst) {
  auto pp = split_sz(mda.root, fst);
  const int ans = get_min_suf(pp.second);
  mda.root = merge(pp.first, pp.second);
  return ans;
}

// Returns the largest suffix sum over positions fst..size()-1, counting the
// empty suffix (so the result is >= 0). The treap is split at fst and merged
// back before returning.
int get_max_suf(mdata &mda, int fst) {
  auto pp = split_sz(mda.root, fst);
  const int ans = get_max_suf(pp.second);
  mda.root = merge(pp.first, pp.second);
  return ans;
}

// Computes one output value for the currently stored deltas, given bound mx
// (distribute_candies passes the box capacity c[i]).
// fst is the position after the last element at which the suffix range
// reaches mx. If that element (position fst - 1) is positive, the result is
// mx plus the minimum suffix sum from fst on. Otherwise (including fst == 0),
// it is 0 plus the maximum suffix sum from fst on.
int get_ans(mdata &mda, int mx) {
  const int fst = get_fst_diff(mda, mx);
  const bool from_top = fst != 0 && mda.get(fst - 1) > 0;
  const int start = from_top ? mx : 0;
  if (start != 0) {
    return start + get_min_suf(mda, fst);
  }
  return start + get_max_suf(mda, fst);
}

// Sweeps the boxes left to right. Before box b is answered, mda holds exactly
// the queries j with l[j] <= b <= r[j], keyed by query index j with value v[j].
vector<int> distribute_candies(vector<int> c, vector<int> l, vector<int> r,
                               vector<int> v) {
  const int n = static_cast<int>(c.size());
  const int q = static_cast<int>(l.size());
  // starts[b]: queries whose range begins at box b.
  // ends[b]:   queries whose range ended at box b - 1.
  vector<vector<int>> starts(n + 1), ends(n + 1);
  for (int j = 0; j < q; ++j) {
    starts[l[j]].push_back(j);
    ends[r[j] + 1].push_back(j);
  }
  mdata mda;
  vector<int> result;
  result.reserve(n);
  for (int box = 0; box < n; ++box) {
    for (const int j : ends[box]) {
      erase(mda, j, v[j]);
    }
    for (const int j : starts[box]) {
      add(mda, j, v[j]);
    }
    result.push_back(get_ans(mda, c[box]));
  }
  return result;
}
# 결과 실행 시간 메모리 Grader output
1 Correct 0 ms 204 KB Output is correct
2 Correct 0 ms 204 KB Output is correct
3 Correct 3 ms 460 KB Output is correct
4 Correct 4 ms 460 KB Output is correct
5 Correct 12 ms 636 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Execution timed out 5066 ms 29964 KB Time limit exceeded
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 1 ms 332 KB Output is correct
2 Execution timed out 5061 ms 15740 KB Time limit exceeded
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 0 ms 204 KB Output is correct
2 Correct 1 ms 332 KB Output is correct
3 Execution timed out 5026 ms 19136 KB Time limit exceeded
4 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 0 ms 204 KB Output is correct
2 Correct 0 ms 204 KB Output is correct
3 Correct 3 ms 460 KB Output is correct
4 Correct 4 ms 460 KB Output is correct
5 Correct 12 ms 636 KB Output is correct
6 Execution timed out 5066 ms 29964 KB Time limit exceeded
7 Halted 0 ms 0 KB -