#include "bits/stdc++.h"
using namespace std;
// Short type aliases and helper macros shared by the whole solution.
using ll = long long;             // 64-bit signed integer
using pii = std::pair<int, int>;  // (int, int) pair
#define all(x) (x).begin(), (x).end()        // full iterator range of a container
#define sz(x) static_cast<int>((x).size())   // container size as a signed int
#define pb push_back
#define F first
#define S second
// Untie C++ streams from C stdio and from each other for fast bulk I/O.
#define FIO std::ios_base::sync_with_stdio(false); std::cin.tie(nullptr)
const int N = 200100;    // capacity for employee ids (1-based)
const int R1 = 510;      // region bound for the precomputed cnt table (r <= 500 branch)
const int R2 = 25010;    // capacity for region ids (1-based)
int n, r, q;             // number of employees, regions and queries
int h[N];                // h[v]: region of employee v
int par[N];              // par[v]: supervisor (tree parent) of v; only read for v != 1
vector<int> chi[N];      // chi[v]: direct subordinates of v, in input order
vector<int> emp[R2];     // emp[k]: employees of region k
vector<pii> empIn[R2];   // empIn[k]: (in-time, employee) for employees of region k
vector<pii> empOut[R2];  // empOut[k]: (out-time, employee) for employees of region k
int in[N];               // in[v]: preorder (Euler-tour) entry time of v
int out[N];              // out[v]: largest entry time inside v's subtree
int timer;               // last Euler-tour time handed out
// cnt[a][b]: number of pairs (x, y) with x a strict ancestor of y, h[x] == a,
// h[y] == b. 64-bit because the total can exceed INT_MAX on deep trees with
// N = 200000 (int would hit signed-overflow UB).
long long cnt[R1][R1];
int temp[R1];            // scratch buffer of size R1
// Accumulates ancestor/descendant region pairs for the subtree rooted at `node`:
// for every pair (x, y) inside that subtree where x is a strict ancestor of y,
// cnt[h[x]][h[y]] is incremented (regions are 1-based, so column 0 is left alone).
//
// Implemented as an explicit-stack DFS that tracks, per region, how many nodes
// of that region lie on the current root path. This keeps the call stack flat
// (a 200000-node chain cannot overflow it) and needs O(R1 * R1) extra memory
// instead of one R1-sized vector per pending subtree (which could reach
// hundreds of MB on caterpillar-shaped trees).
//
// Precondition: every h[] in the subtree is < R1 (main only calls this when r <= 500).
// Returns a heap-allocated vector of size R1 whose entry k is the number of nodes
// of region k in the subtree of `node`; the caller owns it and must delete it.
vector<int>* dfs(int node) {
    // onPath[a]: nodes of region a that are strict ancestors of the node being entered.
    vector<int> onPath(R1, 0);
    // Transposed accumulator: acc[b * R1 + a] counts (region a above, region b below)
    // pairs, so the per-node update below walks contiguous memory.
    vector<long long> acc(static_cast<size_t>(R1) * R1, 0);
    vector<int>* hist = new vector<int>(R1);

    // Visit y: every region count on the current path contributes pairs ending at y.
    auto enter = [&](int y) {
        long long* row = &acc[static_cast<size_t>(h[y]) * R1];
        for (int a = 0; a < R1; a++) row[a] += onPath[a];
        onPath[h[y]]++;
        (*hist)[h[y]]++;
    };

    // Each stack entry is (vertex, index of the next child of that vertex to visit).
    vector<pair<int, int>> stk;
    enter(node);
    stk.push_back({node, 0});
    while (!stk.empty()) {
        const int u = stk.back().first;
        const int next = stk.back().second;
        if (next < sz(chi[u])) {
            stk.back().second++;
            const int v = chi[u][next];
            enter(v);
            stk.push_back({v, 0});
        } else {
            onPath[h[u]]--;  // u's subtree is finished; it is no longer an ancestor
            stk.pop_back();
        }
    }

    for (int b = 1; b < R1; b++) {
        for (int a = 0; a < R1; a++) {
            cnt[a][b] += acc[static_cast<size_t>(b) * R1 + a];
        }
    }
    return hist;
}
// Assigns preorder (Euler-tour) times to the subtree rooted at `node`, continuing
// from the global `timer`: in[v] is v's 1-based visit index and out[v] is the
// largest in[] inside v's subtree, so y is in v's subtree iff in[v] <= in[y] <= out[v].
// Children are visited in chi[] order. An explicit stack replaces recursion so a
// deep, chain-like tree cannot overflow the call stack.
void dfs2(int node) {
    // Each stack entry is (vertex, index of the next child of that vertex to visit).
    vector<pair<int, size_t>> stk;
    in[node] = ++timer;
    stk.push_back({node, 0});
    while (!stk.empty()) {
        const int u = stk.back().first;
        const size_t next = stk.back().second;
        if (next < chi[u].size()) {
            stk.back().second++;
            const int v = chi[u][next];
            in[v] = ++timer;
            stk.push_back({v, 0});
        } else {
            out[u] = timer;  // every descendant of u has been numbered by now
            stk.pop_back();
        }
    }
}
int main() {
FIO;
cin >> n >> r >> q;
for (int i = 1; i <= n; i++) {
if (i != 1) {
cin >> par[i];
chi[par[i]].pb(i);
}
cin >> h[i];
emp[h[i]].pb(i);
}
if (r <= 500) {
dfs(1);
for (int i = 0; i < q; i++) {
int r1, r2;
cin >> r1 >> r2;
int ans = cnt[r1][r2];
cout << ans << endl;
}
} else {
dfs2(1);
for (int i = 1; i <= n; i++) {
empIn[h[i]].pb({in[i], i});
empOut[h[i]].pb({out[i], i});
}
for (int i = 1; i <= r; i++) {
sort(all(empIn[i]));
sort(all(empOut[i]));
}
for (int i = 0; i < q; i++) {
int r1, r2;
cin >> r1 >> r2;
int ptrR2 = 0;
int ptrIn = 0;
int ptrOut = 0;
int sum = 0;
while (ptrIn < sz(empIn[h[r1]]) or ptrOut < sz(empOut[h[r1]])) {
int curIn = INT_MAX;
int curOut = INT_MAX;
int curR2 = INT_MAX;
if (ptrIn < sz(empIn[h[r1]])) curIn = empIn[h[r1]][ptrIn].F - 1;
if (ptrOut < sz(empOut[h[r1]])) curOut = empOut[h[r1]][ptrOut].F;
if (ptrR2 < sz(empIn[r2])) curR2 = empIn[r2][ptrR2].F;
int cur = min({curIn, curOut, curR2});
while (ptrR2 < sz(empIn[r2]) and cur >= empIn[r2][ptrR2].F) {
sum++;
ptrR2++;
}
while (ptrIn < sz(empIn[h[r1]]) and cur >= empIn[h[r1]][ptrIn].F - 1) {
temp[empIn[h[r1]][ptrIn].S] -= sum;
ptrIn++;
}
while (ptrOut < sz(empOut[h[r1]]) and cur >= empOut[h[r1]][ptrOut].F) {
temp[empOut[h[r1]][ptrIn].S] += sum;
ptrOut++;
}
}
int ans = 0;
for (int j = 0; j < R1; j++) {
ans += temp[j];
temp[j] = 0;
}
cout << ans << endl;
}
}
return 0;
}
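// Judge status table pasted from the submission page; kept inside a comment
// so the file still compiles.
/*
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/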