Submission #1301798

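| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1301798 | | tisho | Regions (IOI09_regions) | C++20 | 50 / 100 | 2751 ms | 30680 KiB |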
#include <bits/stdc++.h>

using namespace std;

const int MAXR = 25000;
const int MAXN = 200000;

// big[r]: region r has more than sqrt(n) employees, so its answers are precomputed.
bool big[MAXR + 5];
// v[r]: employees of region r (main sorts each bucket by in[]).
// graph[u]: direct subordinates of employee u.
std::vector<int> v[MAXR + 5], graph[MAXN + 5];
// a[u]: region of employee u; in[u]/out[u]: Euler-tour entry/exit times;
// region: the big region currently being processed by dfs.
int n, r, q, region, a[MAXN + 5], in[MAXN + 5], out[MAXN + 5], timer;

// Precomputed answers: ans[r1][r2] for every big region r1.
//
// Rows are indexed directly by region id, which can be as large as MAXR, so a
// fixed table with fewer rows than MAXR + 5 would be written out of bounds by a
// big region with a large id. A dense MAXR x MAXR int table would be far too large,
// so each row (MAXR + 5 zero-initialised ints) is allocated the first time it is
// touched. Only big regions ever get a row, and at most about sqrt(n) regions can
// have more than sqrt(n) members, which bounds the memory used.
struct LazyAnswerTable
{
    std::vector<int> rows[MAXR + 5];

    // Returns row `id`, allocating it as all zeros on first access.
    std::vector<int>& operator[](int id)
    {
        std::vector<int>& cells = rows[id];
        if(cells.empty()) cells.assign(MAXR + 5, 0);
        return cells;
    }
} ans;

// Strict weak ordering of employees by their Euler-tour entry time in[],
// i.e. the order in which dfs_tour first visits them.
bool cmp(int x, int y)
{
    const int lhs = in[x];
    const int rhs = in[y];
    return lhs < rhs;
}

// Assigns Euler-tour timestamps from the shared global counter `timer`:
// in[u] when u is first reached and out[u] after all of u's descendants have
// been stamped. Consequently w lies strictly below u iff in[u] < in[w] < out[u].
// Children are visited in graph[] order and `parent` is skipped, exactly as in
// a recursive pre/post-order walk, so the timestamps are identical to that walk.
//
// An explicit stack is used instead of recursion: recursion depth would equal
// the tree depth, and a path of up to MAXN employees can exhaust the call stack.
void dfs_tour(int node, int parent)
{
    struct Frame
    {
        int vertex;             // vertex whose children are being scanned
        int from;               // vertex it was reached from (skipped as a child)
        std::size_t next_child; // index into graph[vertex] of the next child to try
    };

    std::vector<Frame> pending;
    in[node] = ++timer;
    pending.push_back({node, parent, 0});

    while(!pending.empty())
    {
        Frame& top = pending.back();
        if(top.next_child == graph[top.vertex].size())
        {
            // All children finished: close this vertex's interval.
            out[top.vertex] = ++timer;
            pending.pop_back();
            continue;
        }

        const int child = graph[top.vertex][top.next_child++];
        if(child == top.from) continue;

        // Copy before push_back, which may reallocate and invalidate `top`.
        const int current = top.vertex;
        in[child] = ++timer;
        pending.push_back({child, current, 0});
    }
}

// Precomputes answers for the big region held in the global `region`.
// For every vertex u reachable from `node` (children from graph[], skipping
// `parent`), adds to ans[region][a[u]] the value `counter` plus the number of
// vertices on the path from `node` down to u's parent whose region is `region`.
// main calls dfs(1, 0, 0), after which ans[region][r2] is the number of pairs
// (ancestor in `region`, descendant in r2).
//
// An explicit stack is used instead of recursion: recursion depth would equal
// the tree depth, and a path of up to MAXN employees can exhaust the call stack.
void dfs(int node, int parent, int counter)
{
    struct Frame
    {
        int vertex;    // vertex to account for
        int from;      // vertex it was reached from (skipped as a child)
        int ancestors; // matching ancestors above `vertex` (including the initial counter)
    };

    // Every update in this traversal goes to the same row, so fetch it once.
    auto& row = ans[region];

    std::vector<Frame> pending;
    pending.push_back({node, parent, counter});

    while(!pending.empty())
    {
        const Frame cur = pending.back();
        pending.pop_back();

        row[a[cur.vertex]] += cur.ancestors;

        // Children see one more matching ancestor when this vertex is in `region`.
        const int below = cur.ancestors + (a[cur.vertex] == region ? 1 : 0);
        for(int child : graph[cur.vertex])
        {
            if(child == cur.from) continue;
            pending.push_back({child, cur.vertex, below});
        }
    }
}

// Answers a query for a small region r1 directly: prints the number of pairs
// (e1, e2) with e1 in region r1, e2 in region r2 and e2 strictly below e1.
//
// The descendants of e1 are exactly the employees w with in[e1] < in[w] < out[e1].
// v[r2] is expected to be sorted by in[] (main sorts each bucket with cmp), so
// that window is a contiguous run that two binary searches can find.
// Cost: O(|v[r1]| * log |v[r2]|).
void solve(int r1, int r2)
{
    const std::vector<int>& targets = v[r2];

    int counter = 0;
    for(int boss : v[r1])
    {
        // First target entered after boss, i.e. the first with in[] > in[boss].
        const auto first = std::partition_point(targets.begin(), targets.end(),
            [&](int w) { return in[w] <= in[boss]; });
        // One past the last target entered before boss's subtree closes. Every
        // element before `first` satisfies this predicate too, so we may start there.
        const auto last = std::partition_point(first, targets.end(),
            [&](int w) { return in[w] < out[boss]; });
        counter += static_cast<int>(last - first);
    }

    // NOTE(review): each answer is flushed with endl. This is presumably required
    // if the grader consumes answers online; confirm before switching to '\n'.
    std::cout << counter << std::endl;
}

// Reads the employee tree and the queries. Answers for regions with more than
// sqrt(n) employees are precomputed with one dfs per region; all other queries
// are answered on demand by solve().
signed main()
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    cin >> n >> r >> q;

    // Employee 1 gives only its region. Every later employee gives "supervisor region".
    cin >> a[1];
    for(int employee = 2; employee <= n; ++employee)
    {
        int boss;
        cin >> boss >> a[employee];
        graph[boss].push_back(employee);
    }

    dfs_tour(1, 0);

    // Bucket employees by region.
    for(int employee = 1; employee <= n; ++employee)
        v[a[employee]].push_back(employee);

    // Sort each bucket by entry time, and precompute answers for the large buckets.
    const double threshold = sqrt(n);
    for(int reg = 1; reg <= r; ++reg)
    {
        sort(v[reg].begin(), v[reg].end(), cmp);
        if(v[reg].size() <= threshold) continue;

        big[reg] = true;
        region = reg;
        dfs(1, 0, 0);
    }

    // NOTE(review): every answer is flushed with endl. This is presumably required
    // if the grader consumes answers online; confirm before switching to '\n'.
    while(q-- != 0)
    {
        int r1, r2;
        cin >> r1 >> r2;

        if(!big[r1])
        {
            solve(r1, r2);
            continue;
        }
        cout << ans[r1][r2] << endl;
    }

    return 0;
}
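| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|

Fetching results...

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|

Fetching results...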