#include <bits/stdc++.h>
using namespace std;
// Compiler optimisation requests.
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
// NOTE(review): unlike the two pragmas above, this one has no GCC prefix;
// presumably the compiler treats it as an unknown pragma and ignores it — confirm.
#pragma optimize("unroll-loops")
// fast: switch off synchronisation with C stdio and untie cin/cout.
#define fast ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0)
// all(x): the iterator pair covering container x front to back.
#define all(x) x.begin(), x.end()
// rll(x): the reverse-iterator pair covering container x back to front.
#define rll(x) x.rbegin(), x.rend()
// COMP(x): erase consecutive duplicate elements of x (std::unique + erase).
#define COMP(x) x.erase(unique(all(x)), x.end())
// Two prime-looking moduli; neither is referenced in the code visible here.
#define MOD 1000000007
#define MOD2 998244353
// sz(x): size of container x converted to ll. The argument is not
// parenthesised, so only pass a plain container expression.
#define sz(x) (ll)x.size()
// Short aliases for the integer / floating-point / pair types used below.
typedef __int128_t lll;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair<ll,ll> pll;
typedef pair<ll, pll> PP;
// Large sentinel value (2e18); not referenced in the code visible here.
const ll Lnf = 2e18;
// Read by main from the first input line: n bounds the vertex loops, r is the
// upper end of the segment-tree index range [1, r], q is the number of queries.
ll n, r, q;
/// One node of the persistent segment tree.
struct Node{
    int l, r;      ///< indices of the left / right child in PST::tree, -1 when absent
    long long v;   ///< sum of all deltas applied inside this node's index range
    Node(): l(-1), r(-1), v(0){}
};

/// root[k] is the index (in seg.tree) of the root of version k.
int root[250001];

/**
 * Persistent segment tree with point add and "difference of two versions"
 * range-sum queries.
 *
 * All nodes live in `tree`; nodes are referred to by index because the vector
 * may reallocate whenever ins() is called. For that reason every method takes
 * the index returned by ins() into a local variable *before* indexing `tree`
 * for the write: writing `tree[a].l = ins()` directly is only well-defined
 * from C++17 on (earlier standards may evaluate `tree[a]` first and then write
 * through a dangling reference).
 */
struct PST{
    std::vector<Node> tree;

    /// Appends a fresh node (no children, value 0) and returns its index.
    long long ins(){
        tree.push_back(Node());
        return static_cast<long long>(tree.size()) - 1;
    }

    /**
     * Builds a complete, all-zero tree for the index range [s, e] below the
     * already existing node `node`. The left child is allocated before the
     * right one, and the left subtree is built before the right subtree.
     */
    void init(long long node, long long s, long long e){
        if(s == e) return;
        const long long left = ins();
        const long long right = ins();
        tree[node].l = static_cast<int>(left);
        tree[node].r = static_cast<int>(right);
        const long long mid = (s + e) >> 1;
        init(left, s, mid);
        init(right, mid + 1, e);
    }

    /**
     * Makes node `cur` (already allocated by the caller) the version obtained
     * from `prv` by adding d at index i, where both nodes cover [s, e].
     * Only the nodes on the path to i are newly allocated; the sibling
     * subtrees are shared with `prv`. A negative `prv` is treated as an empty
     * (all-zero) subtree.
     */
    void upd(long long prv, long long cur, long long s, long long e, long long i, long long d){
        const bool has_prv = prv >= 0;
        tree[cur].v = (has_prv ? tree[prv].v : 0) + d;
        if(s == e) return;

        const long long mid = (s + e) >> 1;
        // Copy what we need from prv before ins() can move the storage.
        const int prv_l = has_prv ? tree[prv].l : -1;
        const int prv_r = has_prv ? tree[prv].r : -1;
        const long long fresh = ins();

        if(i <= mid){
            tree[cur].r = prv_r;
            tree[cur].l = static_cast<int>(fresh);
            upd(prv_l, fresh, s, mid, i, d);
        }
        else{
            tree[cur].l = prv_l;
            tree[cur].r = static_cast<int>(fresh);
            upd(prv_r, fresh, mid + 1, e, i, d);
        }
    }

    /**
     * Returns (sum over [l, r] in version `cur`) - (sum over [l, r] in version
     * `prv`), where both nodes cover [s, e]. A negative `cur` yields 0; a
     * negative `prv` is treated as an empty subtree, also while descending.
     */
    long long query(long long prv, long long cur, long long s, long long e, long long l, long long r){
        if(e < l || r < s || cur < 0) return 0;
        if(l <= s && e <= r) return tree[cur].v - (prv < 0 ? 0 : tree[prv].v);
        const long long mid = (s + e) >> 1;
        const long long prv_l = prv < 0 ? -1 : tree[prv].l;
        const long long prv_r = prv < 0 ? -1 : tree[prv].r;
        return query(prv_l, tree[cur].l, s, mid, l, r)
             + query(prv_r, tree[cur].r, mid + 1, e, l, r);
    }
} seg;
// in[x]: preorder index given to vertex x by dfs (1-based).
// out[x]: value of the preorder counter when dfs finishes x, so the vertices
// visited below x occupy preorder positions in[x]+1 .. out[x].
int in[250001], out[250001];
// Preorder counter advanced by dfs; also the index of the newest tree version.
int cnt;
// adj[x]: every vertex i for which main read x as the number preceding
// regions[i]; dfs walks these lists as the children of x.
vector<int> adj[252525];
// ans[a][b], for converted region ids a, b <= sqn (the "heavy" regions, i.e.
// those at least sqrt(n) in size): summed over every vertex of region a, the
// number of region-b vertices at preorder positions in+1 .. out of that vertex.
ll ans[501][501]; // answers for pairs of regions larger than sqrt(n)
// res[v][a], for a <= sqn: how many vertices on the path from the dfs start
// down to the parent of v have converted region id a. Filled top-down in dfs.
// NOTE(review): with 4-byte int this array alone is about 500 MB — confirm it
// fits the memory limit of the target judge.
int res[250001][501]; // per vertex: count of each heavy region among its ancestors (accumulated on the way down)
// regions[i]: region id of vertex i; main overwrites it with conv[regions[i]].
int regions[250001];
// conv[id]: 1-based rank of original region id when regions are sorted by
// descending member count (ties broken by descending id).
int conv[250001]; // region id conversion
// Largest converted region id whose member count is >= floor(sqrt(n));
// converted ids 1..sqn are handled with the precomputed tables.
int sqn;
/**
 * Preorder walk of the tree below x, filling in[], out[], res[][] and creating
 * one persistent-tree version per visited vertex.
 *
 * For every visited vertex v, in preorder:
 *   - in[v] receives the next preorder index k, and root[k] becomes version
 *     k-1 plus one occurrence of regions[v];
 *   - before a child c of v is entered, res[c][i] (1 <= i <= sqn) is set to
 *     res[v][i] plus one if v itself belongs to region i;
 *   - out[v] receives the preorder counter once all children are done.
 *
 * The walk uses an explicit stack instead of recursion: the depth equals the
 * height of the tree, which can reach the number of vertices on a chain and
 * would otherwise risk exhausting the call stack.
 *
 * @param x    vertex to start from
 * @param prv  vertex to skip in adj[x] (the one we came from); -1 for none
 */
void dfs(ll x, ll prv=-1){
    // One entry per vertex on the path from x to the vertex being processed.
    struct Frame{
        ll v;             // the vertex
        ll parent;        // neighbour of v that must not be re-entered
        std::size_t next; // position in adj[v] of the next neighbour to look at
    };

    // Give v the next preorder index and derive its tree version.
    auto enter = [&](ll v){
        in[v] = ++cnt;
        root[cnt] = seg.ins();
        seg.upd(root[cnt-1], root[cnt], 1, r, regions[v], 1);
    };

    std::vector<Frame> path;
    enter(x);
    path.push_back(Frame{x, prv, 0});

    while(!path.empty()){
        // Copy the fields: push_back below may move the frames.
        const ll v = path.back().v;
        const ll parent = path.back().parent;

        if(path.back().next == adj[v].size()){
            out[v] = cnt;          // every child of v has been finished
            path.pop_back();
            continue;
        }

        const int child = adj[v][path.back().next++];
        if(child == parent) continue;

        for(int i = 1 ; i <= sqn ; i++) res[child][i] = res[v][i] + (regions[v] == i);
        enter(child);
        path.push_back(Frame{child, v, 0});
    }
}
// reg[c]: the vertices whose converted region id is c, in increasing vertex
// order (main appends them while looping i = 1..n).
vector<int> reg[252525];
int main(){
fast;
cin>>n>>r>>q;
vector<array<ll,2>> cnt(r+1);
cin>>regions[1]; cnt[regions[1]][0]++;
for(int i = 2 ; i <= n ; i++){
ll x; cin>>x; adj[x].push_back(i);
cin>>regions[i]; cnt[regions[i]][0]++;
}
for(int i = 1 ; i <= r ; i++)cnt[i][1] = i;
sort(rll(cnt));
for(int i = 0 ; i < r ; i++){
conv[cnt[i][1]] = i+1;
}
for(int i = 1 ; i <= n ; i++)regions[i] = conv[regions[i]], reg[regions[i]].push_back(i);
ll t = sqrt(n);
for(int i = 1 ; i <= r ; i++){
if(t > sz(reg[i]))break;
sqn=i;
}
root[0] = seg.ins(); seg.init(0,1,r);
dfs(1);
for(int i = 1 ; i <= n ; i++)if(regions[i]<=sqn and in[i] != out[i]){
for(int j = 1 ; j <= sqn ; j++){
ans[regions[i]][j] += seg.query(root[in[i]],root[out[i]],1,r,j,j);
}
}
while(q--){
ll r1, r2; cin>>r1>>r2; r1 = conv[r1], r2 = conv[r2];
if(r1>sqn){
ll s = 0;
for(auto i : reg[r1])s += seg.query(root[in[i]],root[out[i]],1,r,r2,r2);
cout<<s<<endl;
}
else if(r2<=sqn){
cout<<ans[r1][r2]<<endl;
}
else{
ll s = 0;
for(auto i : reg[r2])s += res[i][r1];
cout<<s<<endl;
}
}
}
// Residue copied from the submission page (not part of the program):
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |
// # | Verdict | Execution time | Memory | Grader output |
// ---|
// Fetching results... |