답안 #535397

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
535397 2022-03-10T08:21:54 Z chenwz Regions (IOI09_regions) C++11
100 / 100
2276 ms 43420 KB
// IOI2009 – Regions
#include <bits/stdc++.h>
using namespace std;
using VI = std::vector<int>;     // list of ints: child lists and timestamp lists
using IP = std::pair<int, int>;  // int pair: (start, depth) intervals and cache keys
using LL = long long;            // wide accumulator type for query answers
/* One employee, i.e. one node of the management tree. */
struct Emp {
  int tin;  // NOTE(review): intended as the DFS timestamp, but never assigned
            // in this file; dfs() records timestamps in Reg::ids instead.
  int reg;  // 0-based region index of this employee
  VI ch;    // 0-based indices of the employees this one directly manages
};
/* Per-region data built by dfs() and consumed by the query functions. */
struct Reg {
  VI ids;  // DFS timestamps of this region's employees, in increasing order
  /* Intervals of DFS timestamps on which the number of this region's
   * employees that are proper ancestors ("nesting depth") is constant.
   * Each pair is (start, depth): start is the inclusive left end-point of
   * the interval, and the right end-point is implicit (the start of the
   * following pair). Starts are non-decreasing; when several pairs share
   * the same start, only the last of them describes a non-empty interval.
   */
  vector<IP> ranges;
  /* Working depth during the DFS. dfs() increments/decrements it around each
   * employee of this region, so it must start at 0. */
  int depth = 0;
};
// Employee tree indexed by 0-based employee number; sized and filled by main().
std::vector<Emp> ES;
// Per-region data indexed by 0-based region number; sized by main(),
// populated by dfs().
std::vector<Reg> RS;
/* Does a pre-order walk over the subtree rooted at root. id_pool contains
 * the next unused employee ID, and on return it will be updated to again
 * be the next available ID.
 * This procedure builds the RS arrays, after which the tree is no longer
 * needed.
 */
void dfs(int u, int &timer) {
  auto &r = RS[ES[u].reg];
  r.ids.push_back(timer++);
  /* Depth changed, so after this point we need a new range */
  r.ranges.push_back({timer, ++r.depth});
  for (int v : ES[u].ch) dfs(v, timer);
  /* Undo the depth change, and start another interval after the last managee.
   */
  r.ranges.push_back({timer, --r.depth});
}
/* For every employee of r2, count how many of its proper ancestors belong
 * to r1, and return the sum. One binary search over r1.ranges per employee
 * of r2, i.e. O(|r2| log |r1|).
 */
LL query_by_id(const Reg &r1, const Reg &r2) {
  const vector<IP> &rg = r1.ranges;
  LL total = 0;
  for (auto p = r2.ids.begin(); p != r2.ids.end(); ++p) {
    /* No stored depth equals INT_MAX, so `after` is the first interval that
     * starts strictly after *p; the interval containing *p (if any) is the
     * one immediately before it. */
    const auto after = lower_bound(rg.begin(), rg.end(), IP(*p, INT_MAX));
    if (after == rg.begin()) continue;
    total += (after - 1)->second;
  }
  return total;
}

/* Walk r1's nesting intervals. Every employee of r2 whose timestamp lies in
 * [p1, p2) has exactly `depth` ancestors in r1, so the contribution of an
 * interval is depth * (number of r2 ids in [p1, p2)), found with binary
 * searches over r2.ids. O(|r1| log |r2|).
 * The final interval (after the region's last closing) is skipped; it always
 * has depth 0.
 */
LL query_by_range(const Reg &r1, const Reg &r2) {
  LL ans = 0;
  for (size_t i = 0; i + 1 < r1.ranges.size(); i++) {
    const int p1 = r1.ranges[i].first, p2 = r1.ranges[i + 1].first;
    const auto it1 = lower_bound(r2.ids.begin(), r2.ids.end(), p1);
    /* p2 >= p1, so the second search can start at it1. */
    const auto it2 = lower_bound(it1, r2.ids.end(), p2);
    /* Widen before multiplying: depth * count can exceed INT_MAX, and
     * ptrdiff_t is only 32 bits on 32-bit targets. */
    ans += static_cast<LL>(r1.ranges[i].second) * (it2 - it1);
  }
  return ans;
}

/* Merge-style query in O(|r1| + |r2|) time: walk r1's nesting intervals and
 * r2's sorted timestamps together in a single pass. Every r2 employee inside
 * an interval has that interval's depth of r1 ancestors.
 */
LL query_stitch(const Reg &r1, const Reg &r2) {
  LL ans = 0;
  if (r1.ranges.empty()) return ans;
  auto st = r2.ids.begin();
  const auto last = r2.ids.end();
  /* Employees before r1's first interval have no r1 ancestors. */
  while (st != last && *st < r1.ranges[0].first) ++st;
  for (size_t i = 0; i + 1 < r1.ranges.size() && st != last; i++) {
    /* [st, ed) are the r2 employees in [ranges[i].first, ranges[i+1].first). */
    auto ed = st;
    while (ed != last && *ed < r1.ranges[i + 1].first) ++ed;
    /* Widen before multiplying so depth * count cannot overflow int. */
    ans += static_cast<LL>(r1.ranges[i].second) * (ed - st);
    st = ed;
  }
  return ans;
}
/* Answer one query: the sum, over employees of region r2, of the number of
 * their proper ancestors that belong to region r1. Chooses the cheapest of
 * the three query strategies by an estimated cost and memoizes results per
 * (r1, r2). r1 and r2 are 0-based region indices.
 */
LL solve(int r1, int r2) {
  static map<IP, LL> cache;
  const Reg &reg1 = RS[r1], &reg2 = RS[r2];
  const int sz1 = reg1.ids.size(), sz2 = reg2.ids.size();
  /* An empty region yields no pairs. Returning here also avoids log2(0),
   * which is -inf; converting it to int is undefined behavior. */
  if (sz1 == 0 || sz2 == 0) return 0;

  const IP key(r1, r2);
  const auto hit = cache.find(key);
  if (hit != cache.end()) return hit->second;

  /* Estimated costs of query_by_range, query_by_id and query_stitch. The two
   * binary-search strategies are weighted by 5 (a tuning constant). */
  const LL costs[3] = {static_cast<LL>(sz1) * ((int)log2(sz2) + 2) * 5,
                       static_cast<LL>(sz2) * ((int)log2(sz1) + 2) * 5,
                       static_cast<LL>(sz1) + sz2};
  const int k = min_element(costs, costs + 3) - costs;

  LL ans;
  if (k == 0)
    ans = query_by_range(reg1, reg2);
  else if (k == 1)
    ans = query_by_id(reg1, reg2);
  else
    ans = query_stitch(reg1, reg2);
  cache.emplace(key, ans);
  return ans;
}
/* Reads N employees, R regions and Q queries; builds the tree, runs the DFS,
 * then answers each query "r1 r2" (1-based regions) on its own line. */
int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);

  int N, R, Q;
  cin >> N >> R >> Q;
  ES.resize(N);
  RS.resize(R);

  /* Employee 1 (the root) has no manager: only its region is given. */
  cin >> ES[0].reg;
  --ES[0].reg;
  for (int i = 1; i < N; ++i) {
    int boss;
    cin >> boss >> ES[i].reg;
    --ES[i].reg;
    ES[boss - 1].ch.push_back(i);
  }

  int timer = 0;
  dfs(0, timer);

  for (int q = 0; q < Q; ++q) {
    int a, b;
    cin >> a >> b;
    /* endl flushes after every answer. NOTE(review): presumably required if
     * the grader supplies queries online; confirm before switching to '\n'. */
    cout << solve(a - 1, b - 1) << endl;
  }
  return 0;
}
# 결과 실행 시간 메모리 Grader output
1 Correct 0 ms 308 KB Output is correct
2 Correct 1 ms 208 KB Output is correct
3 Correct 3 ms 208 KB Output is correct
4 Correct 5 ms 348 KB Output is correct
5 Correct 8 ms 336 KB Output is correct
6 Correct 15 ms 488 KB Output is correct
7 Correct 23 ms 564 KB Output is correct
8 Correct 36 ms 692 KB Output is correct
9 Correct 23 ms 1384 KB Output is correct
10 Correct 69 ms 1692 KB Output is correct
11 Correct 100 ms 2320 KB Output is correct
12 Correct 124 ms 3236 KB Output is correct
13 Correct 133 ms 3052 KB Output is correct
14 Correct 141 ms 3796 KB Output is correct
15 Correct 159 ms 7708 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 747 ms 9256 KB Output is correct
2 Correct 765 ms 8340 KB Output is correct
3 Correct 1166 ms 14252 KB Output is correct
4 Correct 194 ms 4840 KB Output is correct
5 Correct 306 ms 7636 KB Output is correct
6 Correct 530 ms 7452 KB Output is correct
7 Correct 729 ms 8568 KB Output is correct
8 Correct 886 ms 18184 KB Output is correct
9 Correct 1350 ms 24724 KB Output is correct
10 Correct 1749 ms 33056 KB Output is correct
11 Correct 2193 ms 29896 KB Output is correct
12 Correct 1008 ms 23292 KB Output is correct
13 Correct 1346 ms 25884 KB Output is correct
14 Correct 1510 ms 27584 KB Output is correct
15 Correct 1884 ms 35756 KB Output is correct
16 Correct 2276 ms 43420 KB Output is correct
17 Correct 2055 ms 41440 KB Output is correct