Submission #672737

# | Time | Username | Problem | Language | Result | Execution time | Memory
672737 | — | finn__ | Regions (IOI09_regions) | C++17 | 100 / 100 | 4192 ms | 42692 KiB
#include <bits/stdc++.h>
using namespace std;

// Adjacency lists of the employee tree (0-based, each edge stored both ways).
std::vector<std::vector<unsigned>> g;
// h[u]: 0-based region of employee u.
std::vector<unsigned> h;
// y[u]: preorder index of u, filled by preorder().
std::vector<unsigned> y;
// z[u]: one past the last preorder index inside u's subtree, filled by preorder().
std::vector<unsigned> z;

// Numbers the subtree rooted at u in preorder, starting at index i.
// Afterwards y[x] is x's preorder index and z[x] is one past the last index
// used inside x's subtree, so x's subtree occupies the range [y[x], z[x]).
//   u - subtree root to number
//   p - neighbour of u to skip (its parent); pass a value that is no node
//       index, e.g. -1 (wraps to UINT_MAX), for the tree root
//   i - first index to assign (default 0)
// Returns the next unused index (i + size of u's subtree).
// An explicit stack replaces recursion so very deep trees (e.g. a long path)
// cannot overflow the call stack. Children are still visited in g[x] order,
// so the numbering is identical to the recursive formulation.
unsigned preorder(unsigned u, unsigned p, unsigned i = 0)
{
    struct Frame
    {
        unsigned node;     // node whose children are being visited
        unsigned parent;   // neighbour of node that must not be re-entered
        std::size_t next;  // index in g[node] of the next neighbour to try
    };
    std::vector<Frame> stack;

    y[u] = i++;
    stack.push_back({u, p, 0});
    while (!stack.empty())
    {
        Frame &top = stack.back();
        if (top.next == g[top.node].size())
        {
            // Every child of top.node has been numbered: close its subtree.
            z[top.node] = i;
            stack.pop_back();
            continue;
        }
        const unsigned v = g[top.node][top.next++];
        if (v == top.parent)
            continue;
        const unsigned from = top.node;  // copy before push_back may invalidate `top`
        y[v] = i++;
        stack.push_back({v, from, 0});
    }
    return i;
}

// A run of consecutive preorder positions [i, j] (inclusive; empty when i > j)
// that all share the same count v.
struct interval
{
    unsigned i; // first preorder position of the run
    unsigned j; // last preorder position of the run (inclusive)
    unsigned v; // count shared by every position in the run
};

void buidl_intv(
    unsigned u, unsigned p, vector<vector<interval>> &intv,
    vector<unsigned> &reg_end, vector<unsigned> &num_reg)
{
    if (y[u])
    {
        intv[h[u]].push_back({reg_end[h[u]], y[u] - 1, num_reg[h[u]]});
        reg_end[h[u]] = y[u];
    }
    num_reg[h[u]]++;

    for (unsigned v : g[u])
        if (v != p)
            buidl_intv(v, u, intv, reg_end, num_reg);

    intv[h[u]].push_back({reg_end[h[u]], z[u] - 1, num_reg[h[u]]});
    reg_end[h[u]] = z[u];
    num_reg[h[u]]--;
}

int main()
{
    size_t n, r, q;
    cin >> n >> r >> q;

    g = vector<vector<unsigned>>(n);
    h = vector<unsigned>(n);
    cin >> h[0];
    h[0]--;

    for (unsigned u = 1; u < n; u++)
    {
        unsigned v;
        cin >> v >> h[u];
        h[u]--;
        g[u].push_back(v - 1);
        g[v - 1].push_back(u);
    }

    y = vector<unsigned>(n);
    z = vector<unsigned>(n);
    assert(preorder(0, -1) == n);

    vector<vector<unsigned>> reg(r);

    for (unsigned i = 0; i < n; i++)
        reg[h[i]].push_back(y[i]);
    for (vector<unsigned> &v : reg)
        sort(v.begin(), v.end());

    vector<vector<interval>> intv(r);
    vector<unsigned> reg_end(r, 0), num_reg(r, 0);
    buidl_intv(0, -1, intv, reg_end, num_reg);

    for (size_t i = 0; i < q; i++)
    {
        unsigned r1, r2;
        cin >> r1 >> r2;
        r1--, r2--;

        auto it = intv[r1].begin();
        auto jt = reg[r2].begin();
        unsigned total = 0;

        while (it != intv[r1].end() && jt != reg[r2].end())
        {
            if (*jt < it->i)
                jt++;
            else if (it->j < *jt)
                it++;
            else if (it->i <= *jt && *jt <= it->j)
            {
                total += it->v;
                jt++;
            }
        }

        cout << total << endl;
    }
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...