This submission was migrated from the previous version of oj.uz, which used a different machine for grading. This submission may get a different result if resubmitted.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
// Shorthand macros. NOTE: `f` and `s` are object-like macros, so EVERY bare
// `f` / `s` token in this file is rewritten to `first` / `second` -- e.g. the
// Fenwick array `int f[N]` further down is really named `first`, and the
// local `int s` in get() is really named `second`.
#define f first
#define s second
#define pb push_back
#define ep emplace
#define eb emplace_back
#define lb lower_bound
#define ub upper_bound
// `all` is function-like, so it only expands when followed by `(`; the global
// vector also named `all` below is therefore unaffected by it.
#define all(x) x.begin(), x.end()
#define rall(x) x.rbegin(), x.rend()
#define uniquev(v) sort(all(v)), (v).resize(unique(all(v)) - (v).begin())
#define mem(f,x) memset(f , x , sizeof(f))
#define sz(x) (int)(x).size()
#define __lcm(a, b) (1ll * ((a) / __gcd((a), (b))) * (b))
#define mxx *max_element
#define mnn *min_element
#define cntbit(x) __builtin_popcountll(x)
#define len(x) (int)(x.length())
// Array bound for nodes / regions / Euler positions (with slack).
const int N = 2e5 + 100;
// Size threshold: a region with at least `lim` nodes is "large"; pairs that
// involve a large region are precomputed in dfs, the rest answered per query.
const int lim = 500;
// lst[r]: nodes belonging to region r; ed[u]: children of u;
// all: list of large regions.
vector <int> lst[N], ed[N], all;
// ans[r1][r2]: precomputed answers for pairs with a large region;
// save[u]: region -> number of nodes of that region in u's subtree.
map <int, ll> ans[N], save[N];
// c[u]: region of u; in[u]/ou[u]: Euler-tour entry/exit positions
// (timer starts at 1, so the first assigned position is 2).
int c[N], in[N], ou[N], timer = 1;
// Post-order DFS from u. It does three things:
//   * builds save[u] (region -> number of nodes of that region in u's
//     subtree, u itself included), using small-to-large merging;
//   * assigns the Euler-tour range in[u]..ou[u] of u's subtree;
//   * accumulates the precomputed answers for pairs involving a large region
//     (one with at least `lim` nodes).
//
// For every node u exactly one branch contributes, so each (c[u], x) pair is
// counted once:
//   * c[u] large -> ans[c[u]][x] += count of x in subtree(u), for every x;
//   * c[u] small -> the same, but only for the large regions x in `all`.
// (The previous version ran both loops when c[u] was large, which doubled
// ans[large][large].)
//
// Consequently, whenever r1 or r2 is large, ans[r1][r2] is the number of pairs
// (a, b) with c[a] = r1, c[b] = r2 and b in subtree(a). When r1 == r2 this
// includes a == b, matching the brute-force branch in main.
//
// NOTE(review): recursion depth equals the tree height (up to n); deep,
// path-like trees rely on a large enough stack.
void dfs(int u) {
    save[u][c[u]]++;
    in[u] = ++timer;
    for (auto x : ed[u]) {
        dfs(x);
        // Small-to-large: always merge the smaller map into the larger one.
        if (sz(save[u]) < sz(save[x])) {
            swap(save[u], save[x]);
        }
        for (const auto& t : save[x]) {
            save[u][t.first] += t.second;
        }
        // The child's map is never read again; release its nodes now instead
        // of keeping every merged-away map alive until the program exits.
        save[x].clear();
    }
    const auto& here = save[u];
    if (sz(lst[c[u]]) >= lim) {
        // Large ancestor region: record the count of every region below u.
        for (const auto& t : here) {
            ans[c[u]][t.first] += t.second;
        }
    } else {
        // Small ancestor region: only pairs with a large descendant region
        // are precomputed; small/small pairs are answered in main.
        for (int x : all) {
            auto it = here.find(x);
            if (it != here.end())
                ans[c[u]][x] += it->second;
        }
    }
    ou[u] = timer;
}
// Fenwick (binary indexed) tree over Euler-tour positions 1..N-1. main uses it
// to count how many marked nodes fall in a subtree range [in[x], ou[x]].
int f[N];
// Adds x to position i.
void upd(int i, int x) {
    while (i < N) {
        f[i] += x;
        i += i & -i;
    }
}
// Returns the sum over positions 1..i.
int get(int i) {
    int total = 0;
    while (i != 0) {
        total += f[i];
        i &= i - 1; // strip the lowest set bit (same as i -= i & -i)
    }
    return total;
}
// Returns the sum over positions l..r, inclusive.
int range(int l, int r) {
    const int upto_r = get(r);
    const int before_l = get(l - 1);
    return upto_r - before_l;
}
int main() {
int n, r, q;
cin >> n >> r >> q;
cin >> c[1];
for (int i = 2; i <= n; i++) {
int x;
cin >> x;
cin >> c[i];
ed[x].pb(i);
}
for (int i = 1; i <= n; i++) {
lst[c[i]].pb(i);
}
for (int i = 1; i <= r; i++) {
if (sz(lst[i]) >= lim)
all.pb(i);
}
dfs(1);
for (int i = 1; i <= q; i++) {
int r1, r2;
cin >> r1 >> r2;
if (max(sz(lst[r1]), sz(lst[r2])) >= lim) {
cout << ans[r1][r2] << '\n';
} else {
int cnt = 0;
for (auto x : lst[r2])
upd(in[x], 1);
for (auto x : lst[r1]) {
cnt += range(in[x], ou[x]);
}
for (auto x : lst[r2])
upd(in[x], -1);
cout << cnt << '\n';
}
}
return 0;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |