// Submission #681583
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 681583 | - | puppy | Regions (IOI09_regions) | C++17 | 70 / 100 | 4547 ms | 131072 KiB
#include <iostream>
#include <vector>
#include <utility>
#include <map>
#include <cmath>
#include <algorithm>
using namespace std;
// ---- Shared state for the whole program (node ids start at 1; slot 0 is the root's
// ---- fictitious parent and stays zero-initialised) ----

// N: number of nodes. SQ: integer square root of N, the unit of the size thresholds.
int N, SQ;
// R: number of regions. Q: number of queries left to answer.
int R, Q;
// cnt[r]: how many nodes belong to region r.
int cnt[25005];
// cpx[r]: index of region r inside `exceed`, or -1 when r is not a heavy region.
int cpx[25005];
// region[v]: region of node v. par[v]: parent of node v (par[1] is never read from input).
int region[200005], par[200005];
// s[v]..e[v]: inclusive range of Euler-tour indices occupied by v's strict descendants
// (the range is empty, e[v] == s[v] - 1, when v is a leaf).
int s[200005], e[200005];

// Capacity for the heavy-region columns of dp.  A region is heavy when it has more than
// 6 * SQ nodes, so H heavy regions need H * (6 * SQ + 1) <= N.  With the node arrays
// sized for N <= 200000 (SQ = 447) that gives H <= 74.  The previous capacity of 170
// made dp alone ~129.7 MiB; 75 columns keep it at ~57 MiB.
constexpr int MAX_HEAVY = 75;
static_assert(200000 / (6 * 447 + 1) < MAX_HEAVY, "dp must have a column for every heavy region");

// dp[v][k]: number of nodes on the path from the root to v (v included) whose region
// is exceed[k].  Row 0 is all zeros and serves as the root's parent row.
int dp[200005][MAX_HEAVY];
// g[v]: children of node v, in input order.
std::vector<int> g[200005];
// exceed: ids of the heavy regions, in increasing order.
std::vector<int> exceed;
// ett: nodes in pre-order (Euler tour) as produced by dfs().
std::vector<int> ett;
// Appends the pre-order traversal of the subtree rooted at v to `ett` and records, for
// every node x in that subtree, the inclusive tour-index range [s[x], e[x]] of x's
// strict descendants (s[x] is one past x's own index; e[x] is the last index in x's
// subtree, so a leaf gets e[x] == s[x] - 1).
//
// The traversal uses an explicit stack instead of recursion: the depth equals the tree
// height, which can be as large as the node count on a path-shaped tree and would
// otherwise risk overflowing the call stack.
void dfs(int v)
{
    // Each frame is (node, index of the next child of that node still to be visited).
    std::vector<std::pair<int, int>> frames;

    s[v] = static_cast<int>(ett.size()) + 1;
    ett.push_back(v);
    frames.emplace_back(v, 0);

    while (!frames.empty()) {
        const int node = frames.back().first;
        const int next = frames.back().second;

        if (next < static_cast<int>(g[node].size())) {
            // Advance the cursor first: the emplace_back below may reallocate `frames`.
            ++frames.back().second;
            const int child = g[node][next];
            s[child] = static_cast<int>(ett.size()) + 1;
            ett.push_back(child);
            frames.emplace_back(child, 0);
        } else {
            // All children done: the tour currently ends with node's last descendant.
            e[node] = static_cast<int>(ett.size()) - 1;
            frames.pop_back();
        }
    }
}
// Fills dp[x][k] for every node x in the subtree rooted at v, where p is v's parent
// (0 for the root, whose row of dp is all zeros).  dp[x][k] becomes the parent's value
// plus one when x itself belongs to heavy region exceed[k].
//
// Implemented with an explicit stack rather than recursion so that a very deep tree
// cannot overflow the call stack.  Children are processed in a different order than a
// recursive walk would, which is harmless: a row depends only on its parent's row, and
// that row is always written before the children are pushed.
void dfs2(int v, int p)
{
    const int heavy_count = static_cast<int>(exceed.size());

    // Pending (node, parent) pairs whose dp row has not been written yet.
    std::vector<std::pair<int, int>> pending;
    pending.emplace_back(v, p);

    while (!pending.empty()) {
        const int node = pending.back().first;
        const int parent = pending.back().second;
        pending.pop_back();

        // Column of node's own region, or -1 when that region is not heavy.
        const int own = (cnt[region[node]] > 6 * SQ) ? cpx[region[node]] : -1;
        for (int k = 0; k < heavy_count; k++) {
            dp[node][k] = dp[parent][k] + (k == own ? 1 : 0);
        }

        for (const int child : g[node]) {
            pending.emplace_back(child, node);
        }
    }
}
// Per-region lookup tables, indexed by region id:
//   pos[r]    - Euler-tour indices (increasing) of the nodes that belong to region r;
//   member[r] - the node ids that belong to region r, in input order.
std::vector<int> pos[25005], member[25005];
int main()
{
    cin >> N >> R >> Q;
    SQ = sqrt(N);
    fill(cpx, cpx + 25005, -1);
    for (int i = 1; i <= N; i++) {
        if (i >= 2) cin >> par[i], g[par[i]].push_back(i);
        cin >> region[i];
        member[region[i]].push_back(i);
        cnt[region[i]]++;
    }
    dfs(1);
    for (int i = 0; i < N; i++) {
        pos[region[ett[i]]].push_back(i);
    }
    for (int i = 1; i <= R; i++) {
        if (cnt[i] > 6 * SQ) exceed.push_back(i);
    }
    for (int k = 0; k < (int)exceed.size(); k++) cpx[exceed[k]] = k;
    dfs2(1, 0);
    map<pair<int, int>, int> res;
    while (Q--) {
        int r1, r2; cin >> r1 >> r2;
        int ans = 0;
        if (cnt[r1] <= 6 * SQ) {
            for (int &i:member[r1]) {
                //r2 색 중 i의 서브트리 내부에들어오는 것 개수
                int st = lower_bound(pos[r2].begin(), pos[r2].end(), s[i]) - pos[r2].begin();
                int en = upper_bound(pos[r2].begin(), pos[r2].end(), e[i]) - pos[r2].begin();
                --en;
                ans += (en - st + 1);
            }
        }
        else {
            bool flag = false;
            if (cnt[r2] > SQ) {
                int tmp = res[make_pair(r1, r2)];
                if (tmp > 0) ans = tmp - 1, flag = true;
            }
            if (!flag) {
                for (int &i:member[r2]) {
                    ans += dp[par[i]][cpx[r1]];
                }
                if (cnt[r2] > SQ) res[make_pair(r1, r2)] = ans + 1;
            }
        }
        cout << ans << '\n';
        cout.flush();
    }
    return 0;
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...