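// Submission #931664 (original page header: "제출 #931664")
//
// # | Submission time | ID | Problem | Language | Result | Execution time | Memory
// 931664 | (not captured) | LOLOLO | Regions (IOI09_regions) | C++17 | 60 / 100 | 8087 ms | 110132 KiB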
#include <bits/stdc++.h>
using namespace std;
// Alias for `long long` (at least 64 bits); used as the count type stored in the ans/save maps.
typedef long long ll;

// Competitive-programming shorthand macros. Only f, s, pb and sz are referenced in the
// visible code; the rest are unused here.
//
// WARNING: `f` and `s` are object-like macros, so EVERY later identifier token `f` or `s`
// is rewritten to `first` / `second` -- not just `.f` / `.s` member accesses on pairs.
// For example, the Fenwick array declared as `int f[N]` is really named `first`, and the
// local `int s` inside get() is really named `second`. Do not use `f`/`s` as ordinary names.
#define           f     first
#define           s     second
// Container/algorithm name shorthands.
#define           pb    push_back
#define           ep    emplace
#define           eb    emplace_back
#define           lb    lower_bound
#define           ub    upper_bound
// Iterator-pair helpers. These are function-like macros, so only `all(` / `rall(` are
// expanded; the global vector named `all` (never followed by `(`) is left untouched.
#define       all(x)    x.begin(), x.end()
#define      rall(x)    x.rbegin(), x.rend()
// Sorts v and removes duplicate elements in place.
#define   uniquev(v)    sort(all(v)), (v).resize(unique(all(v)) - (v).begin())
// memset over the whole object f (only meaningful for fixed-size arrays, not pointers).
#define     mem(f,x)    memset(f , x , sizeof(f))
// Container size as a signed int, avoiding signed/unsigned comparisons.
#define        sz(x)    (int)(x).size()
// LCM computed as long long, dividing before multiplying. Relies on GCC's __gcd; note that
// names beginning with a double underscore are reserved for the implementation.
#define  __lcm(a, b)    (1ll * ((a) / __gcd((a), (b))) * (b))
// Value of the maximum / minimum element, e.g. mxx(all(v)).
#define          mxx    *max_element
#define          mnn    *min_element
// Number of set bits in x (as unsigned long long).
#define    cntbit(x)    __builtin_popcountll(x)
// String length as a signed int.
#define       len(x)    (int)(x.length())

const int N = 2e5 + 100;
const int lim = 500;

vector <int> lst[N], ed[N], all;
map <int, ll> ans[N], save[N];
int c[N], in[N], ou[N], timer = 1;

void dfs(int u) {
    save[u][c[u]]++;
    in[u] = ++timer;
    for (auto x : ed[u]) {
        dfs(x);
        if (sz(save[u]) < sz(save[x])) {
            swap(save[u], save[x]);
        }

        for (auto t : save[x]) {
            save[u][t.f] += t.s;
        }
    }

    for (auto x : all) {
        if (save[u].find(x) == save[u].end())
            continue;
        ans[c[u]][x] += save[u][x];
    }

    if (sz(lst[c[u]]) >= lim) {
        for (auto x : save[u]) {
            ans[c[u]][x.f] += x.s;
        }
    }

    ou[u] = timer;
}

// Fenwick (binary indexed) tree over positions 1..N-1. Note that the `f` macro rewrites this
// identifier, so the array is actually named `first`.
int f[N];

// Adds x to the value stored at position i. Positions outside [1, N-1] are ignored:
// i == 0 would never advance (0 & -0 == 0) and loop forever, and a negative i would write
// to negative indices.
void upd(int i, int x) {
    if (i <= 0)
        return;
    for (; i < N; i += i & (-i))
        f[i] += x;
}

int get(int i) {
    int s = 0;
    for (; i; i -= i & (-i))
        s += f[i];

    return s;
}

// Sum of the values at positions l..r inclusive. An empty range (l > r) yields 0 instead of
// a negative difference of prefix sums.
int range(int l, int r) {
    if (l > r)
        return 0;
    return get(r) - get(l - 1);
}


// Entry point. Input format: "n r q", then the region of node 1, then for nodes 2..n the
// parent id followed by the node's region, then q queries "r1 r2". For every query it prints
// the number of pairs (x, y) where x is in region r1, y is in region r2, and x is a strict
// ancestor of y.
//   * If either region has at least `lim` members, the answer is read from ans[r1][r2],
//     which dfs(1) is expected to have filled. A missing entry means 0, and find() is used
//     so that queries do not insert entries.
//   * Otherwise both regions are small. The r2 members below a member x of r1 are those
//     whose entry time lies in (in[x], ou[x]]; they are counted by binary search over r2's
//     sorted entry times. Results are memoized per (r1, r2), so repeated queries cost one
//     lookup.
//   * Region ids outside [1, max region id] have no members and are answered with 0.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n, r, q;
    std::cin >> n >> r >> q;
    std::cin >> c[1];

    for (int i = 2; i <= n; i++) {
        int parent;
        std::cin >> parent >> c[i];
        ed[parent].push_back(i);
    }

    // Largest valid region id: the declared count r, or any larger id seen in the input.
    int maxColor = r;
    for (int i = 1; i <= n; i++) {
        lst[c[i]].push_back(i);
        maxColor = std::max(maxColor, c[i]);
    }

    for (int i = 1; i <= r; i++) {
        if (sz(lst[i]) >= lim)
            all.push_back(i);
    }

    dfs(1);

    // entry[g]: Euler entry times of region g's members, sorted ascending.
    std::vector<std::vector<int>> entry(maxColor + 1);
    for (int i = 1; i <= n; i++)
        entry[c[i]].push_back(in[i]);
    for (auto &times : entry)
        std::sort(times.begin(), times.end());

    std::map<std::pair<int, int>, long long> memo;
    for (int i = 1; i <= q; i++) {
        int r1, r2;
        std::cin >> r1 >> r2;

        long long res = 0;
        if (r1 < 1 || r2 < 1 || r1 > maxColor || r2 > maxColor) {
            res = 0;  // unknown region: no members, hence no pairs
        } else if (std::max(sz(lst[r1]), sz(lst[r2])) >= lim) {
            const auto hit = ans[r1].find(r2);
            res = (hit == ans[r1].end()) ? 0 : hit->second;
        } else {
            const std::pair<int, int> key(r1, r2);
            auto [slot, fresh] = memo.try_emplace(key, 0);
            if (fresh) {
                const std::vector<int> &times = entry[r2];
                long long total = 0;
                for (int x : lst[r1]) {
                    // Strict descendants of x have entry times in (in[x], ou[x]].
                    total += std::upper_bound(times.begin(), times.end(), ou[x]) -
                             std::upper_bound(times.begin(), times.end(), in[x]);
                }
                slot->second = total;
            }
            res = slot->second;
        }
        std::cout << res << '\n';
    }

    return 0;
}

// Per-subtask results (not captured when this page was saved):
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...