Submission #850299

# | Time | Username | Problem | Language | Result | Execution time | Memory
850299 |  | Shreyan_Paliwal | Regions (IOI09_regions) | C++17 | 100 / 100 | 1754 ms | 70408 KiB
// #pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
using namespace std;

/// Renders any stream-insertable value as a string.
/// Floating-point values come out in fixed notation with 7 fractional digits;
/// integers, characters and strings are printed unchanged.
template<typename T>
std::string tostr(const T& value) {
    std::ostringstream out;
    out.precision(7);
    out.setf(std::ios::fixed, std::ios::floatfield);
    out << value;
    return out.str();
}

/// Minimal "{}"-placeholder formatter.
/// Replaces successive "{}" occurrences in `format`, left to right, with
/// tostr(arg) for each argument in order. Arguments left over after the last
/// placeholder are ignored; placeholders with no matching argument are kept
/// verbatim.
template<typename... Args>
std::string fstr(const std::string& format, const Args&... args) {
    std::string result = format;
    std::size_t pos = 0;

    auto replaceArg = [&](const auto& arg) {
        pos = result.find("{}", pos);
        if (pos == std::string::npos) return;
        const std::string text = tostr(arg);
        result.replace(pos, 2, text);
        // Continue searching after the inserted text, so a "{}" that appears
        // inside an argument's rendering is never substituted again.
        pos += text.size();
    };

    (replaceArg(args), ...);
    return result;
}

// From this line on, every `int` token in the file is textually replaced by
// `long long` (which is why the entry point below is declared `signed main`).
#define int long long
// Large sentinel that still leaves INT_MAX of headroom below LLONG_MAX / 2;
// qry() uses it as "+infinity" when a cursor has run past its vector.
const int INF = LLONG_MAX / 2 - INT_MAX;
// 0x3f byte pattern repeated over 8 bytes; presumably meant for memset-style
// fills (not referenced in the visible code) — TODO confirm before relying on it.
const int INFmem = 0x3f3f3f3f'3f3f3f3f;
// 32-bit-range "infinity" (not referenced in the visible code).
const int inf = INT_MAX / 2;
// 0x3f byte pattern over 4 bytes (not referenced in the visible code).
const int infmem = 0x3f3f3f3f;
// Common modulus constant (not referenced in the visible code).
const int MOD = 1e9 + 7;

/// Reusable rooted-tree helper that reads a tree with a region label per node
/// and computes an Euler-tour numbering grouped by region.
/// SZ  - capacity of node-indexed arrays.
/// SZ2 - capacity of region-indexed arrays.
template<int SZ, int SZ2> struct Tree {
    std::vector<int> adj[SZ];  // adj[v] = children of v (edges point parent -> child)
    int r[SZ];                 // r[v] = 0-based region of node v

    /// Reads n nodes from std::cin: the 1-based region of node 0, then for
    /// each node i = 1..n-1 its 1-based parent followed by its 1-based region.
    void init2(int n) {
        std::cin >> r[0];
        r[0]--;
        for (int i = 1; i < n; i++) {
            int a;
            std::cin >> a >> r[i];
            a--;    // parent ids are 1-based in the input; nodes are 0-based here
            r[i]--;
            adj[a].push_back(i);
        }
    }

    // Euler tour: st[v] = entry index, en[v] = exclusive exit index, so the
    // subtree of v occupies positions [st[v], en[v]).
    int st[SZ], en[SZ], cnt = 0;
    // Per region: entry indices (increasing) and exit indices (non-decreasing)
    // of its nodes, in the order the DFS visits / leaves them.
    std::vector<int> sts[SZ2], ens[SZ2];

    /// Recursive DFS from `nd` (reached from `par`) filling st/en/sts/ens.
    void flatten(int nd, int par) {
        st[nd] = cnt++;
        sts[r[nd]].push_back(st[nd]);
        for (auto i : adj[nd])
            if (i != par)
                flatten(i, nd);
        en[nd] = cnt;
        ens[r[nd]].push_back(en[nd]);
    }

    int s[SZ2];  // region-indexed scratch array (not used by the methods above)
};



// Array capacities and sqrt-decomposition parameters.
constexpr int maxn = 200005;  // capacity of node-indexed arrays
constexpr int maxr = 25005;   // capacity of region-indexed arrays
constexpr int sqrtn = 450;    // regions with at least this many nodes are "heavy"
constexpr int sqrtr = 160;    // number of heavy-region slots in prec / prec2
// NOTE(review): up to maxn / sqrtn (about 444) regions can reach sqrtn nodes,
// which is more than sqrtr slots; whoever hands out slots must stop at sqrtr.

// Input: node count, region count, query count; adj[v] lists the children of
// v, re[v] is the 0-based region of v.
int n, r, q;
std::vector<int> adj[maxn];
int re[maxn];

// Euler tour: st[v] / en[v] are v's entry index and exclusive exit index;
// sts[R] / ens[R] collect those indices for the nodes of region R.
int st[maxn];
int en[maxn];
int cnt = 0;
std::vector<int> sts[maxr];
std::vector<int> ens[maxr];

// s[R] = heavy slot of region R (or -1). prec[slot][R] / prec2[R][slot] hold
// precomputed counts with the heavy region on the ancestor / descendant side.
int s[maxr];
int prec[sqrtr][maxr];
int prec2[maxr][sqrtr];
int cnt2 = 0;

// Euler-tour DFS. Gives node `nd` the entry index st[nd] and the exclusive exit
// index en[nd], so its subtree occupies [st[nd], en[nd]). The same indices are
// appended to sts / ens of nd's region. Because `cnt` only grows, every
// sts[R] ends up increasing and every ens[R] non-decreasing.
void dfs(int nd, int par) {
    const int entry = cnt++;
    st[nd] = entry;
    sts[re[nd]].push_back(entry);

    for (const auto child : adj[nd]) {
        if (child == par) continue;
        dfs(child, nd);
    }

    const int leave = cnt;
    en[nd] = leave;
    ens[re[nd]].push_back(leave);
}

// Region whose slot s[curreg] dfs2 is currently filling (set before each call).
int curreg;

// One pass over the tree for the heavy region `curreg`. For every node whose
// region R differs from curreg it adds:
//   prec[s[curreg]][R]  += number of curreg nodes among its strict ancestors
//   prec2[R][s[curreg]] += number of curreg nodes in its strict subtree
// `num` = number of strict ancestors of nd that belong to curreg.
// Returns the number of curreg nodes in nd's subtree, nd included.
int dfs2(int nd, int par, int num) {
    const bool isCur = (re[nd] == curreg);
    const int slot = s[curreg];

    if (!isCur) prec[slot][re[nd]] += num;

    const int passDown = num + (isCur ? 1 : 0);
    int inSubtree = isCur ? 1 : 0;
    for (const auto child : adj[nd]) {
        if (child != par) inSubtree += dfs2(child, nd, passDown);
    }

    if (!isCur) prec2[re[nd]][slot] += inSubtree;
    return inSubtree;
}

// Counts pairs (x, y) with x in region a, y in region b and y inside x's
// subtree, i.e. st[x] <= st[y] < en[x]. It sweeps the sorted entry positions
// of a and b together with the sorted exit positions of a. `open` is the number
// of region-a subtrees covering the current sweep position. The sweep stops
// once every b entry or every a exit has been consumed.
int qry(int a, int b) {
    const auto& aIn = sts[a];
    const auto& aOut = ens[a];
    const auto& bIn = sts[b];

    size_t ia = 0, ib = 0, io = 0;
    int open = 0, total = 0;
    while (ib < bIn.size() && io < aOut.size()) {
        const bool moreA = ia < aIn.size();
        const int closeAt = aOut[io];
        const int bAt = bIn[ib];
        if (closeAt <= bAt && (!moreA || closeAt <= aIn[ia])) {
            // Exits win ties: en[] is exclusive, so a position equal to an
            // exit index is already outside that subtree.
            --open;
            ++io;
        } else if (moreA && aIn[ia] <= bAt) {
            ++open;
            ++ia;
        } else {
            total += open;
            ++ib;
        }
    }
    return total;
}

void solve() {
    fill(s, s+maxr, -1); // set regions to -1

    cin >> n >> r >> q;
    cin >> re[0]; re[0]--;
    for (int i = 1; i < n; i++) 
        { int a; cin >> a >> re[i]; a--, re[i]--; adj[a].push_back(i); }

    // precompute
    dfs(0, 0);
    for (int i = 0; i < r; i++)
        if (sts[i].size() >= sqrtn) {
            s[i] = cnt2++; curreg = i;
            dfs2(0, 0, 0);   
        }

    // for (int i = 0; i < n; i++) {
    //     cout << "i " << st[i] << ' ';
    //     for (auto j : adj[i]) cout << st[j] << ' ';
    //     cout << endl;
    // }

    // for (int i = 0; i < 3; i++) {
    //     cout << "STS "; for (auto j : sts[i]) cout << j << ' '; cout << " | ENS "; for (auto j : ens[i]) cout << j << ' ';
    //     cout << endl;
    // }

    for (int i = 0; i < q; i++) {
        int a, b; cin >> a >> b; a--, b--;
        if (s[a] != -1) {cout << prec[s[a]][b] << endl; continue;}
        if (s[b] != -1) {cout << prec2[a][s[b]] << endl; continue;}
        cout << qry(a, b) << endl;
    }
}

// #define LOCAL
// #define CODEFORCES
// Entry point. Declared `signed` because `int` is macro-replaced by
// `long long`. LOCAL: read from "main.in" and print per-case timing.
// CODEFORCES: read a test-case count first; otherwise a single case is run.
signed main() {
#ifdef LOCAL
    freopen("main.in", "r", stdin);
#else
    cin.tie(nullptr);
    ios::sync_with_stdio(false);
#endif

    int t = 1;
#ifdef CODEFORCES
    cin >> t;
#endif

    for (int tc = 1; tc <= t; ++tc) {
#ifdef LOCAL
        cout << fstr("----- Case {} -----", tc) << endl;
        const auto startTime = clock();
        solve();
        cout << fstr("RUNTIME: {}", static_cast<double>(clock() - startTime) / CLOCKS_PER_SEC) << endl;
        cout << fstr("----- END CASE {} -----", tc) << endl;
#else
        solve();
#endif
    }
    return 0;
}

Compilation message (stderr)

regions.cpp: In function 'long long int qry(long long int, long long int)':
regions.cpp:113:15: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  113 |     while (c2 < sts[b].size() && c3 < ens[a].size()) {
      |            ~~~^~~~~~~~~~~~~~~
regions.cpp:113:37: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  113 |     while (c2 < sts[b].size() && c3 < ens[a].size()) {
      |                                  ~~~^~~~~~~~~~~~~~~
regions.cpp:115:31: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  115 |         int nxt = min(min((c1 < sts[a].size() ? sts[a][c1] : INF), sts[b][c2]), ens[a][c3]);
      |                            ~~~^~~~~~~~~~~~~~~
regions.cpp:117:16: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  117 |         if (c1 < sts[a].size() && sts[a][c1] == nxt) { cnt++; c1++; continue; }
      |             ~~~^~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...