Submission #535015

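| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 535015 |  | chenwz | Regions (IOI09_regions) | C++11 | 100 / 100 | 2339 ms | 45032 KiB |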
#include <bits/stdc++.h>
using namespace std;
using VI = vector<int>;
using IP = pair<int, int>;
using LL = long long;
struct Emp { // Employee
  int tin, reg;     // tin(dfs时间戳), reg(Region Id), low()
  int low, high;    /* Lowest&Highest ID of managees */
  VI ch;   // children
};
/* Per-region data filled in by dfs(); the tree is not needed afterwards. */
struct Reg {
  /* Pre-order IDs of this region's employees, strictly increasing
   * (appended in the order the walk assigns them). */
  VI ids;
  /* Sorted list of intervals with a constant nesting level. Each pair is
   * (ID, depth):
   * - ID is the inclusive left end-point of an interval of pre-order IDs.
   *   The right end-point is implicit: it is the next entry's ID.
   * - depth is the number of this region's employees that are proper
   *   ancestors (managers) of every ID inside the interval.
   * Consecutive entries may share an ID (empty intervals); the last entry
   * for a given ID is the one in effect.
   */
  vector<IP> ranges;
  /* Working depth during the DFS. Starts at zero even for a
   * default-initialized Reg; without the initializer, only value
   * initialization (as done by vector::resize) guaranteed zero. */
  int depth = 0;
};
int N;              // number of employees (tree nodes)
int R;              // number of regions
int Q;              // number of queries
vector<Emp> nodes;  // the management tree, indexed by zero-based employee number
vector<Reg> RS;     // per-region data, indexed by zero-based region number
/* Does a pre-order walk over the subtree rooted at root. id_pool contains
 * the next unused employee ID, and on return it will be updated to again
 * be the next available ID.
 * This procedure builds the RS arrays, after which the tree is no longer needed.
 */
void dfs(int u, int &timer) {
  auto& r = RS[nodes[u].reg];
  r.ids.push_back(timer++);
  /* Depth changed, so after this point we need a new range */
  r.ranges.push_back(make_pair(timer, ++r.depth));
  for (int v : nodes[u].ch)
    dfs(v, timer);
  /* Undo the depth change, and start another interval after the last managee. */
  r.ranges.push_back(make_pair(timer, --r.depth));
}
/* Sum, over every employee e of r2, of the number of r1 employees that are
 * proper ancestors (managers) of e. Does one binary search over r1's
 * intervals per r2 employee: O(|r2| log |r1|). */
LL query_by_id(const Reg &r1, const Reg &r2) {
  LL total = 0;
  const auto starts_after = [](int id, const IP &range) { return id < range.first; };
  for (auto e = r2.ids.begin(); e != r2.ids.end(); ++e) {
    /* First interval starting strictly after *e. The entry just before it
     * (the latest one starting at or before *e) is the interval holding *e. */
    auto after = upper_bound(r1.ranges.begin(), r1.ranges.end(), *e, starts_after);
    if (after == r1.ranges.begin()) continue;
    total += (after - 1)->second;
  }
  return total;
}

/* Same sum as query_by_id, computed by iterating over r1's intervals
 * instead. Every r2 employee whose ID lies in [start_i, start_{i+1}) has
 * exactly depth_i managers in r1. Cost: O(|r1| log |r2|).
 */
LL query_by_range(const Reg &r1, const Reg &r2) {
  LL ans = 0;
  if (r1.ranges.empty()) return ans;
  /* Consecutive intervals share an end-point, so interval i's lower bound
   * in r2.ids is interval i-1's upper bound. Carry it over instead of
   * searching for it again. */
  auto lo = lower_bound(r2.ids.begin(), r2.ids.end(), r1.ranges[0].first);
  for (size_t i = 0; i + 1 < r1.ranges.size(); i++) {
    auto hi = lower_bound(lo, r2.ids.end(), r1.ranges[i + 1].first);
    /* Widen before multiplying. depth * count can exceed the range of int,
     * and ptrdiff_t is only 32 bits on some targets. */
    ans += static_cast<LL>(r1.ranges[i].second) * (hi - lo);
    lo = hi;
  }
  return ans;
}

/* Same sum again in O(|r1| + |r2|). Walks r1's intervals and r2's sorted ID
 * list together with one forward-only cursor, instead of binary searching.
 */
LL query_stitch(const Reg &r1, const Reg &r2) {
  LL ans = 0;
  if (r1.ranges.empty()) return ans;
  /* Skip employees before the first interval; they have no r1 managers. */
  auto id = r2.ids.begin();
  while (id != r2.ids.end() && *id < r1.ranges[0].first) ++id;
  for (size_t i = 0; i + 1 < r1.ranges.size() && id != r2.ids.end(); i++) {
    /* Advance past the r2 employees that fall inside interval i. */
    auto first_in_range = id;
    while (id != r2.ids.end() && *id < r1.ranges[i + 1].first) ++id;
    /* Widen before multiplying so depth * count cannot overflow. */
    ans += static_cast<LL>(r1.ranges[i].second) * (id - first_in_range);
  }
  return ans;
}
/* Answer query (r1, r2): the number of (manager in r1, employee in r2)
 * pairs where the manager is a proper ancestor of the employee.
 * Answers are memoised per (r1, r2). The strategy is chosen by a rough cost
 * model over the two region sizes.
 */
LL solve(int r1, int r2) {
  static map<IP, LL> cache;
  const IP key(r1, r2);
  const auto hit = cache.find(key);  // one lookup instead of count() + operator[]
  if (hit != cache.end()) return hit->second;
  const Reg &reg1 = RS[r1], &reg2 = RS[r2];
  const int sz1 = reg1.ids.size(), sz2 = reg2.ids.size();
  /* An empty region has no managers or no employees, so the answer is 0.
   * Returning here also keeps log2(0) (= -inf, whose conversion to int is
   * undefined behaviour) out of the cost model below. */
  if (sz1 == 0 || sz2 == 0) return 0;
  const LL costs[3] = {
    // query_by_range: one binary search in reg2 per reg1 interval
    static_cast<LL>(sz1) * (static_cast<int>(log2(sz2)) + 2) * 5,
    // query_by_id: one binary search in reg1 per reg2 employee
    static_cast<LL>(sz2) * (static_cast<int>(log2(sz1)) + 2) * 5,
    // query_stitch: a single linear merge
    static_cast<LL>(sz1) + sz2
  };
  const int k = min_element(costs, costs + 3) - costs;
  LL ans;
  if (k == 0) ans = query_by_range(reg1, reg2);
  else if (k == 1) ans = query_by_id(reg1, reg2);
  else ans = query_stitch(reg1, reg2);
  cache[key] = ans;
  return ans;
}
int main() {
  ios::sync_with_stdio(false);
  cin.tie(0);
  cin >> N >> R >> Q;
  nodes.resize(N);
  RS.resize(R);
  /* Employee 1 (index 0) is the root; only its region is given.
   * All input numbers are 1-based and are converted to 0-based here. */
  cin >> nodes[0].reg;
  nodes[0].reg--;
  /* Every other employee: its manager (parent), then its region. */
  for (int i = 1; i < N; ++i) {
    int boss;
    cin >> boss >> nodes[i].reg;
    nodes[i].reg--;
    nodes[boss - 1].ch.push_back(i);
  }
  int next_id = 0;
  dfs(0, next_id);
  /* endl (not '\n') flushes after every answer. This is kept from the
   * original, presumably so each answer is visible before the next query is
   * read. NOTE(review): confirm the judge does not need this before
   * switching to '\n'. */
  int a, b;
  for (int q = 0; q < Q; ++q) {
    cin >> a >> b;
    cout << solve(a - 1, b - 1) << endl;
  }
  return 0;
}
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...