Submission #1146607

# | Time | Username | Problem | Language | Result | Execution time | Memory
1146607 |  | mannshah1211 | Regions (IOI09_regions) | C++20 | 100 / 100 | 3239 ms | 102204 KiB
#include <bits/stdc++.h>

using namespace std;

#ifdef LOCAL
#include "deb.h"
#else
#define debug(...) 
#endif

// Size threshold for the sqrt decomposition: a region with at least B
// employees is "big". Employees are partitioned among regions, so there are
// at most n / B big regions, and every small region has fewer than B members.
constexpr int B = 450;

// Reads one IOI 2009 "Regions" instance from stdin and answers its queries on
// stdout, one answer per line.
//
// Input (1-based indices): "n r q"; the region of employee 1 (the root); then
// for employees 2..n a line "supervisor region"; then q lines "r1 r2".
// Answer to (r1, r2): the number of pairs (a, d) with region[a] == r1,
// region[d] == r2 and a a strict ancestor of d in the supervisor tree.
//
// Every answer is flushed right after it is written (on every branch), so the
// program also behaves when queries are delivered one at a time.
//
// Cost per query (|X| = number of employees in region X):
//   big r1,   big r2   : O(1) lookup in a table built during the DFS;
//   big r1,   small r2 : O(|r2| log n) via Euler-tour time counts;
//   small r1, any r2   : O(|r1| log n) via Euler-tour time counts.
// Memory is O(n + r + big^2) (no per-employee x per-big-region table), and the
// traversal is iterative, so a chain-shaped tree cannot overflow the stack.
void solve() {
  int n, r, q;
  cin >> n >> r >> q;

  vector<int> region(n);
  vector<vector<int>> employees(r), children(n);
  cin >> region[0];
  --region[0];
  employees[region[0]].push_back(0);
  for (int i = 1; i < n; i++) {
    int boss;
    cin >> boss >> region[i];
    --boss;
    --region[i];
    employees[region[i]].push_back(i);
    children[boss].push_back(i);
  }

  // Dense indices for big regions; -1 marks a small region.
  vector<int> big_index(r, -1);
  int big_count = 0;
  for (int k = 0; k < r; k++) {
    if (static_cast<int>(employees[k].size()) >= B) {
      big_index[k] = big_count++;
    }
  }

  // One iterative pre-order DFS from the root (employee 0):
  //  - tin[v] / tout[v]: v's subtree occupies entry times [tin[v], tout[v]).
  //  - on_path[j]: number of big-region-j employees on the current root path.
  //  - big_pairs[j][k]: answer for (big j, big k). When a node of big region k
  //    is entered, every big-j node currently on the path is a strict ancestor
  //    of it, so on_path[j] is added before the node itself is counted.
  vector<vector<long long>> big_pairs(big_count, vector<long long>(big_count));
  vector<int> on_path(big_count);
  vector<int> tin(n), tout(n);
  vector<size_t> next_child(n, 0);
  int timer = 0;
  auto enter = [&](int v) {
    tin[v] = timer++;
    const int k = big_index[region[v]];
    if (k != -1) {
      for (int j = 0; j < big_count; j++) {
        big_pairs[j][k] += on_path[j];
      }
      ++on_path[k];
    }
  };
  vector<int> path = {0};
  enter(0);
  while (!path.empty()) {
    const int v = path.back();
    if (next_child[v] < children[v].size()) {
      const int u = children[v][next_child[v]++];
      enter(u);
      path.push_back(u);
    } else {
      // All of v's children are done: close its subtree interval.
      path.pop_back();
      tout[v] = timer;
      const int k = big_index[region[v]];
      if (k != -1) {
        --on_path[k];
      }
    }
  }

  // Per region, the sorted entry and exit times of its employees.
  vector<vector<int>> tins(r), touts(r);
  for (int v = 0; v < n; v++) {
    tins[region[v]].push_back(tin[v]);
    touts[region[v]].push_back(tout[v]);
  }
  for (int k = 0; k < r; k++) {
    sort(tins[k].begin(), tins[k].end());
    sort(touts[k].begin(), touts[k].end());
  }
  // Number of entries of the sorted vector xs that are <= x (resp. < x).
  auto count_le = [](const vector<int>& xs, int x) -> long long {
    return upper_bound(xs.begin(), xs.end(), x) - xs.begin();
  };
  auto count_lt = [](const vector<int>& xs, int x) -> long long {
    return lower_bound(xs.begin(), xs.end(), x) - xs.begin();
  };

  while (q--) {
    int r1, r2;
    cin >> r1 >> r2;
    --r1;
    --r2;
    const int b1 = big_index[r1];
    const int b2 = big_index[r2];
    long long ans = 0;
    if (b1 != -1 && b2 != -1) {
      ans = big_pairs[b1][b2];
    } else if (b1 != -1) {
      // r1-ancestors of d are the r1 intervals [tin, tout) containing tin[d]:
      // those started by tin[d] minus those already closed by tin[d] (every
      // interval closed by then also started by then). d is not in r1 here
      // because r1 is big and r2 is small, so only strict ancestors count.
      for (int d : employees[r2]) {
        ans += count_le(tins[r1], tin[d]) - count_le(touts[r1], tin[d]);
      }
    } else {
      // Strict r2-descendants of a: entry times strictly inside a's interval,
      // i.e. in (tin[a], tout[a]); a itself is excluded even when r1 == r2.
      for (int a : employees[r1]) {
        ans += count_lt(tins[r2], tout[a]) - count_le(tins[r2], tin[a]);
      }
    }
    cout << ans << '\n' << flush;
  }
}

// Program entry point: runs exactly one test case. Reading a test-case count
// is disabled; enable the commented read for multi-test input.
int main() {
  int tests = 1;
  // cin >> tests;
  for (; tests > 0; --tests) {
    solve();
  }
  return 0;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...