#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Lowers i to j when i > j; returns true iff i was changed.
template<class T> bool cmin(T &i, T j) {
    if (i > j) {
        i = j;
        return true;
    }
    return false;
}

// Raises i to j when i < j; returns true iff i was changed.
template<class T> bool cmax(T &i, T j) {
    if (i < j) {
        i = j;
        return true;
    }
    return false;
}

// Mersenne Twister seeded from the current steady_clock reading.
mt19937 mrand(chrono::steady_clock::now().time_since_epoch().count());
// Uniform integers on the closed range [0, 2^30].
uniform_int_distribution<int> ui(0, 1 << 30);
// One draw from `ui` driven by `mrand`.
int rng() { return ui(mrand); }
// Capacity of every per-node, per-Euler-time and per-region array below.
constexpr int N = 200001;
// How many of the most populous regions get precomputed tables (rows of pref/res).
constexpr int B = 50;

// pref[i][t]: number of nodes of heavy region order[i] whose tin is <= t.
int pref[B][N];
// res[i][r]: precomputed answer with order[i] as ancestor region and r as descendant region.
ll res[B][N];
// Region ids sorted by decreasing population (only the first R entries are filled).
int order[N];
// nodes[r]: nodes belonging to region r, appended in DFS preorder.
vector<int> nodes[N];
// adj[u]: children of node u.
vector<int> adj[N];
// Node count, region count and query count, read in main().
int n, R, q;
// regions[u]: 0-based region of node u.
int regions[N];
// Euler-tour clock plus per-node entry time and last-descendant time.
int timer = 0;
int tin[N];
int tout[N];
// Euler-tour the subtree rooted at u (children visited in adj[] order).
// Sets tin[x] to x's preorder index and tout[x] to the largest tin in x's subtree,
// and appends every node to nodes[regions[x]] in preorder.
// Iterative (explicit stack) so a path-shaped tree of depth ~N cannot overflow
// the call stack, which the recursive form could.
void dfs(int u=0) {
    // Each frame: a node and the index of its next child to descend into.
    std::vector<std::pair<int, std::size_t>> stk;
    auto enter = [&](int x) {
        tin[x] = timer++;
        nodes[regions[x]].push_back(x);
        stk.emplace_back(x, 0);
    };
    enter(u);
    while (!stk.empty()) {
        const int x = stk.back().first;
        std::size_t &next = stk.back().second;
        if (next < adj[x].size()) {
            const int v = adj[x][next++];
            enter(v);  // may reallocate stk; `next` is not touched after this
        } else {
            // All descendants have been numbered.
            tout[x] = timer - 1;
            stk.pop_back();
        }
    }
}
// Per-node scratch counter: main() seeds it for one region, propagates it with
// dfs2(), reads it, then zeroes it again.
int dp[N];
// Push values down the tree: afterwards dp[v] (for v in u's subtree) holds the
// sum of the original dp values on the path from u to v, inclusive.
// Iterative so a deep, path-shaped tree cannot overflow the call stack. A node is
// only pushed after its own value is final (its parent added into it first), so
// the processing order does not affect the result.
void dfs2(int u=0) {
    std::vector<int> stk;
    stk.push_back(u);
    while (!stk.empty()) {
        const int x = stk.back();
        stk.pop_back();
        for (const int v : adj[x]) {
            dp[v] += dp[x];
            stk.push_back(v);
        }
    }
}
namespace atcoder {
// Fenwick (binary indexed) tree over positions [0, n): point add, half-open range sum.
template <class T> struct fenwick_tree {
  public:
    fenwick_tree() : _n(0) {}
    // n zero-initialised positions. explicit: an int is not implicitly a tree.
    explicit fenwick_tree(int n) : _n(n), data(n) {}

    // Adds x at position p (requires 0 <= p < n).
    void add(int p, T x) {
        assert(0 <= p && p < _n);
        p++;
        while (p <= _n) {
            data[p - 1] += x;
            p += p & -p;
        }
    }

    // Sum over positions [l, r) (requires 0 <= l <= r <= n). Does not modify the tree.
    T sum(int l, int r) const {
        assert(0 <= l && l <= r && r <= _n);
        return sum(r) - sum(l);
    }

  private:
    int _n;
    std::vector<T> data;

    // Sum over the prefix [0, r).
    T sum(int r) const {
        T s = 0;
        while (r > 0) {
            s += data[r - 1];
            r -= r & -r;
        }
        return s;
    }
};
}  // namespace atcoder
using namespace atcoder;
// Shared Fenwick tree indexed by Euler-tour time: main() adds +1 at the entry
// times of one region's nodes for a query and subtracts them again afterwards.
fenwick_tree<int> ft{N};
int main() {
cin >> n >> R >> q;
cin >> regions[0];
regions[0]--;
for (int i=1;i<n;i++) {
int p;
cin >> p;
adj[--p].push_back(i);
cin >> regions[i];
--regions[i];
}
dfs();
iota(order,order+R,0);
sort(order,order+R,[&](int r1,int r2){
return nodes[r1].size()>nodes[r2].size();
});
for (int i=0;i<min(R,B);i++) {
int r=order[i];
for (int &u: nodes[r])
pref[i][tin[u]]++;
for (int j=1;j<n;j++)
pref[i][j]+=pref[i][j-1];
for (int &u: nodes[r])
dp[u]++;
dfs2();
for (int j=0;j<n;j++) if (regions[j]!=r) {
res[i][regions[j]]+=dp[j];
}
for (int j=0;j<n;j++)
dp[j]=0;
}
auto solve = [&](int r1,int r2) {
for (int i=0;i<B;i++) if (order[i]==r1) {
return res[i][r2];
}
for (int i=0;i<B;i++) if (order[i]==r2) {
ll ans=0;
for (int &u: nodes[r1]) {
ans+=pref[i][tout[u]]-(tin[u]?pref[i][tin[u]-1]:0);
}
return ans;
}
for (int &v: nodes[r2])
ft.add(tin[v],1);
ll ans=0;
for (int &u: nodes[r1])
ans+=ft.sum(tin[u],tout[u]+1);
for (int &v: nodes[r2])
ft.add(tin[v],-1);
return ans;
};
for (int r1,r2;q--;) {
cin >> r1 >> r2;
--r1,--r2;
cout << solve(r1,r2) << endl;
// solve naive
}
}
/*
Judge results page pasted after the source; kept as a comment so the file compiles.
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/