#include <bits/stdc++.h>
using namespace std;
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
// The original "#pragma optimize(...)" (without GCC) is ignored by GCC.
#pragma GCC optimize("unroll-loops")
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define all(x) x.begin(), x.end()
#define rll(x) x.rbegin(), x.rend()
#define COMP(x) x.erase(unique(all(x)), x.end())
#define MOD 1000000007
#define MOD2 998244353
#define sz(x) (ll)x.size()
typedef __int128_t lll;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair<ll,ll> pll;
typedef pair<ll, pll> PP;
const ll Lnf = 2e18;

// n: number of tree nodes, r: number of regions, q: number of queries.
ll n, r, q;
// Euler-tour indices: in[v] is v's preorder index, out[v] the largest index
// inside v's subtree, so strict descendants w satisfy in[v] < in[w] <= out[v].
int in[250001], out[250001];
// Running preorder counter used by dfs.
int cnt;
// adj[p] holds the children of p (main only pushes supervisor -> employee).
vector<int> adj[252525];
// res[v][h]: number of strict ancestors of v in heavy region h (filled top-down).
// NOTE(review): 250001*501 ints is roughly 500 MB per table; together with
// res2 this may exceed typical memory limits if many rows are touched.
int res[250001][501];
// res2[v][h]: number of strict descendants of v in heavy region h (filled bottom-up).
int res2[250001][501];
// Region of each node; main relabels it through conv.
int regions[250001];
// Original region id -> relabelled id (1 = most populous region).
int conv[250001];
// Relabelled regions 1..sqn are the "heavy" ones (>= floor(sqrt(n)) members).
int sqn;
// ans[a][b] for heavy a, b: number of pairs (u, v) with u in region a,
// v in region b and u a strict ancestor of v. 64-bit because the count can
// reach ~n^2/4, far beyond INT_MAX.
ll ans[501][501];
// Depth-first traversal from x (prv = node to skip, -1 for the root).
// For each tree edge x -> child it:
//   * sets res[child][h] = res[x][h] + (regions[x] == h) for heavy h in 1..sqn,
//     i.e. the count of child's strict ancestors in region h;
//   * when child's own region is heavy, adds those counts into ans[h][region(child)];
//   * after child's subtree is finished, accumulates child's descendant
//     counts (plus child itself) into res2[x][h].
// It also assigns Euler-tour indices in[] (preorder) and out[] (subtree end).
// Runs in O(n * sqn) time. An explicit stack is used instead of recursion so
// that a path-shaped tree of depth ~n cannot overflow the call stack.
void dfs(ll x, ll prv=-1){
    struct Frame {
        int node;               // vertex being expanded
        ll parent;              // neighbour to skip (mirrors the recursive prv)
        std::size_t nextChild;  // index into adj[node] of the next child to visit
    };
    std::vector<Frame> stk;
    in[x] = ++cnt;
    stk.push_back({(int)x, prv, 0});
    while(!stk.empty()){
        const int cur = stk.back().node;
        const ll par = stk.back().parent;
        if(stk.back().nextChild < adj[cur].size()){
            const int child = adj[cur][stk.back().nextChild++];
            if(child == par) continue;
            // Pre-order work done by the recursive version before descending.
            for(int i = 1 ; i <= sqn ; i++){
                res[child][i] = res[cur][i] + (regions[cur] == i);
                if(regions[child] <= sqn) ans[i][regions[child]] += res[child][i];
            }
            in[child] = ++cnt;
            stk.push_back({child, cur, 0});
        } else {
            // All children of cur are done: close its Euler interval.
            out[cur] = cnt;
            stk.pop_back();
            if(!stk.empty()){
                // Post-child work the recursive version did in the parent.
                const int up = stk.back().node;
                for(int i = 1 ; i <= sqn ; i++) res2[up][i] += res2[cur][i] + (regions[cur] == i);
            }
        }
    }
}
vector<int> reg[252525];
vector<array<int,2>> event[252525];
int main(){
fast;
cin>>n>>r>>q;
vector<array<ll,2>> cnt(r+1);
cin>>regions[1]; cnt[regions[1]][0]++;
for(int i = 2 ; i <= n ; i++){
ll x; cin>>x; adj[x].push_back(i);
cin>>regions[i]; cnt[regions[i]][0]++;
}
for(int i = 1 ; i <= r ; i++)cnt[i][1] = i;
sort(rll(cnt));
for(int i = 0 ; i < r ; i++){
conv[cnt[i][1]] = i+1;
}
for(int i = 1 ; i <= n ; i++)regions[i] = conv[regions[i]], reg[regions[i]].push_back(i);
ll t = sqrt(n);
for(int i = 1 ; i <= r ; i++){
if(t > sz(reg[i]))break;
sqn=i;
}
dfs(1);
for(int i = 1 ; i <= r ; i++){
sort(all(reg[i]), [&](int a, int b){ return in[a] < in[b]; });
for(auto j : reg[i])event[i].push_back({in[j]+1,1}), event[i].push_back({out[j]+1,-1});
sort(all(event[i]));
}
while(q--){
ll r1, r2; cin>>r1>>r2; r1 = conv[r1], r2 = conv[r2];
if(r1<=sqn and r2<=sqn){ //up up
cout<<ans[r1][r2]<<endl;
}
else if(r1<=sqn){
ll s = 0;
for(auto i : reg[r2])s += res[i][r1];
cout<<s<<endl;
}
else if(r2<=sqn){
ll s=0;
for(auto i : reg[r1])s += res2[i][r2];
cout<<s<<endl;
}
else{
ll s = 0, sum = 0;
for(int i = 0, j = 0 ; i < sz(reg[r2]) ; i++){
while(j<sz(event[r1]) and event[r1][j][0] <= in[reg[r2][i]]){
sum += event[r1][j][1];
j++;
}
s += sum;
}
cout<<s<<endl;
}
}
}
/*
 * Judge status table pasted from the submission page (not part of the program):
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */