제출 #529915

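| # | 제출 시각 (submitted at) | 아이디 (user) | 문제 (problem) | 언어 (language) | 결과 (result) | 실행 시간 (time) | 메모리 (memory) |
|---|---|---|---|---|---|---|---|
| 529915 |  | msg555 | Aliens (IOI16_aliens) | C++14 | 0 / 100 | 1 ms | 460 KiB |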
#include "aliens.h"

#include <algorithm>
#include <cassert>
#include <iostream>
#include <limits>
#include <set>
#include <vector>

using namespace std;

// Returns x * x computed in 64-bit arithmetic.
// Takes long long (not int) so callers that already hold 64-bit values,
// such as sample(), are not silently narrowed to int before squaring.
long long squared(long long x) {
  return x * x;
}

template<typename T>
struct line_elem {
  // Either a line y = a * x + b (point_query == false), or, when point_query
  // is set, a bare abscissa stored in x_start. The mixed ordering below
  // presumably lets point queries act as lookup keys in an ordered container.
  T a;
  T b;
  T x_start;
  bool point_query;

  // Zero line, not a point query.
  line_elem() : line_elem(T(0), T(0), T(0)) {}

  // Point-query key at abscissa x (a and b are zero).
  line_elem(T x) : line_elem(T(0), T(0), x) {
    point_query = true;
  }

  // Line slope * x + intercept, optionally tagged with a start abscissa.
  line_elem(T slope, T intercept, T start = 0)
      : a(slope), b(intercept), x_start(start), point_query(false) {}

  // Value of the line at x.
  T get(T x) const {
    const T scaled = a * x;
    return scaled + b;
  }

  // Compares by x_start if either side is a point query; otherwise compares
  // lexicographically by (a, b).
  bool operator<(const line_elem& rhs) const {
    const bool by_position = point_query || rhs.point_query;
    if (by_position) {
      return x_start < rhs.x_start;
    }
    if (a < rhs.a) {
      return true;
    }
    if (rhs.a < a) {
      return false;
    }
    return b < rhs.b;
  }
};

template<typename T>
struct rolling_line_hull {
  /*
  Implements a max hull, get the max ax+b for all inserted lines. 

  This requires the a's to be non-decreasing between inserts and for calls
  to get(x) to be on non-decreasing x. This removes the log(N) factor present
  in line_hull and gives amortized constant time insert/query.
   */
  rolling_line_hull() : query_index(0) {}

  void insert(T a, T b) {
    /* Insert line ax+b */
    assert(lns.empty() || lns.back().a <= a);
    if (!lns.empty() && lns.back().a == a) {
      if (b <= lns.back().b) {
        return;
      }
      lns.resize(lns.size() - 1);
    }

    auto ln = line_elem<T>(a, b);
    if (lns.empty()) {
      ln.x_start = numeric_limits<T>::min();
      lns.push_back(ln);
      return;
    }

    size_t sz = lns.size();
    ln.x_start = cross_x(lns[sz - 1], ln);
    while (sz > 1) {
      if (lns[sz - 1].x_start < ln.x_start) {
        break;
      }
      --sz;
      ln.x_start = cross_x(lns[sz - 1], ln);
    }
    lns.resize(sz);

    ln.x_start = cross_x(lns.back(), ln);
    lns.push_back(ln);
    query_index = min(query_index, lns.size() - 1);
  }

  T get(T x) {
    while (query_index + 1 < lns.size() && x >= lns[query_index + 1].x_start) {
      ++query_index;
    }
    return lns[query_index].get(x);
  }

private:
  T cross_x(const line_elem<T>& X, const line_elem<T>& Y) {
    /* Returns the first x >= to when the lines intersect */
    T da = X.a - Y.a;
    T db = Y.b - X.b;
    assert(da != 0);
    if (da < 0) {
      da = -da;
      db = -db;
    }
    if (db < 0) {
      return db / da;
    }
    return (db + da - 1) / da;
  }

  size_t query_index;
  vector<line_elem<T> > lns;
};

// Capacity of the global work arrays below (presumably sized for up to 1e5
// input points plus slack; confirm against the problem limits).
constexpr std::size_t MAXN{100010};

// Diagonal intervals filled by take_photos, stored as
// (right end, right end - left end).
std::pair<int, int> A[MAXN];
// ADJ[i]: squared overlap between consecutive intervals i-1 and i.
// Static storage, so every entry (including ADJ[0]) starts at zero.
long long ADJ[MAXN];

long long sample(int N, int K, long long C) {
  rolling_line_hull<long long> hl;
  long long val = 0;
  for (int i = 0; i <= N; i++) {
    val = 0;
    if (i) {
      long long x = A[i - 1].first;
      val = -hl.get(x) + squared(x) + C;
    }

    long long sadd = 1 + A[i].second - A[i].first;
    long long b = 2 * sadd;
    long long c = squared(sadd) + val - ADJ[i];
    hl.insert(-b, -c);
  }

  return val - K * C;
}

long long take_photos(int N, int M, int K, vector<int> R, vector<int> C) {
/*
  N = 100000;
  M = 100001;
  K = 50000;
  R.clear();
  C.clear();
  for (int i = 0; i < N; i++) {
    R.push_back(i);
    C.push_back(i + 1);
  }
*/

  for (int i = 0; i < N; i++) {
    int ri = R[i];
    int ci = C[i];
    if (ri > ci) swap(ri, ci);
    A[i] = make_pair(ci, ci - ri);
  }
  sort(A, A + N);

  int j = 0;
  for (int i = 0; i < N; i++) {
    while (j > 0 && A[i].second - A[i].first >= A[j - 1].second - A[j - 1].first) {
      --j;
    }
    A[j++] = A[i];
  }
  N = j;
  K = min(K, N);

  for (int i = 1; i < N; i++) {
    ADJ[i] = squared(max(0, 1 + A[i].second - A[i].first + A[i - 1].first));
  }

  long long lo = 0;
  long long hi = 1 << 16;
  long long md = hi;
  long long mdv = sample(N, K, md);
  while (true) {
    long long v = sample(N, K, 2 * hi);
    if (v <= mdv) {
      hi *= 2;
      break;
    }
    lo = hi;
    hi *= 2;
    md = hi;
    mdv = v;
  }

  while (lo + 4 <= hi) {
    long long nmd;
    if (hi - md < md - lo) {
      nmd = lo + (md - lo) / 2;
    } else {
      nmd = md + (hi - md) / 2;
    }
    long long nmdv = sample(N, K, nmd);

    if (md < nmd) {
      swap(md, nmd);
      swap(mdv, nmdv);
    }

    if (nmdv < mdv) {
      lo = nmd + 1;
    } else {
      hi = md;
      md = nmd;
      mdv = nmdv;
    }
  }

  long long result = numeric_limits<long long>::min();
  for (long long i = lo; i < hi; i++) {
    result = max(result, sample(N, K, i));
  }
  return result;
}
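| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...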