답안 #850299

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
850299 2023-09-16T09:42:18 Z Shreyan_Paliwal Regions (IOI09_regions) C++17
100 / 100
1754 ms 70408 KB
// #pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
using namespace std;

template<typename T>
string tostr(const T& value) {
    // Render the value through a stream in fixed-point notation with seven
    // digits after the decimal point (the precision only matters for
    // floating-point values; integers and strings print unchanged).
    ostringstream out;
    out << fixed << setprecision(7) << value;
    return out.str();
}

// Minimal "{}"-placeholder formatter: each argument, in order, replaces the
// next "{}" in `format` with its tostr() rendering. Extra arguments (with no
// placeholder left) are ignored; extra placeholders are left as-is.
template<typename... Args>
string fstr(const string& format, const Args&... args) {
    string result = format;
    size_t pos = 0;

    auto replaceArg = [&](const auto& arg) {
        pos = result.find("{}", pos);
        if (pos == string::npos) return;
        const string text = tostr(arg);
        result.replace(pos, 2, text);
        // Continue searching after the inserted text, so a "{}" that appears
        // inside an argument's rendering is never treated as a placeholder.
        pos += text.size();
    };

    (replaceArg(args), ...);

    return result;
}

// Every `int` after this line is really `long long` (64-bit), presumably to
// keep the accumulated pair counts in prec / prec2 / qry from overflowing --
// confirm. This is also why the entry point is declared `signed main()`.
#define int long long
// Large sentinel used by qry() to mean "no more entries"; the `- INT_MAX`
// leaves headroom below LLONG_MAX / 2 (presumably so adding two sentinels
// cannot overflow).
const int INF = LLONG_MAX / 2 - INT_MAX;
// The remaining constants are not referenced in the visible code --
// presumably reusable template boilerplate.
const int INFmem = 0x3f3f3f3f'3f3f3f3f;
const int inf = INT_MAX / 2;
const int infmem = 0x3f3f3f3f;
const int MOD = 1e9 + 7;

// Reusable version of the tree reading / Euler-tour flattening that solve()
// performs with globals (this template is not instantiated in this file).
//   SZ  - capacity for the number of nodes
//   SZ2 - capacity for the number of regions
template<int SZ, int SZ2> struct Tree {
    vector<int> adj[SZ];   // adj[v]: children of v (edges stored parent -> child)
    int r[SZ];             // r[v]: 0-based region of node v

    // Reads n nodes: the root's region, then "parent region" for nodes 1..n-1.
    // Both the region and the parent index are 1-based in the input and are
    // converted to 0-based here, the same way solve() does it.
    void init2(int n) {
        cin >> r[0]; r[0]--;
        for (int i = 1; i < n; i++) {
            int a; cin >> a >> r[i];
            a--, r[i]--;
            adj[a].push_back(i);
        }
    }

    int st[SZ], en[SZ], cnt=0;
    vector<int> sts[SZ2], ens[SZ2];

    // Euler tour from nd (par = node we came from). st[v] is v's entry time and
    // en[v] is one past the last entry time inside v's subtree, so u lies in
    // v's subtree iff st[v] <= st[u] < en[v]. Times are also appended to the
    // per-region lists sts / ens, which therefore stay sorted.
    void flatten(int nd, int par) {
        st[nd] = cnt++;
        sts[r[nd]].push_back(st[nd]);
        for (auto i : adj[nd])
            if (i != par)
                flatten(i, nd);
        en[nd] = cnt;
        ens[r[nd]].push_back(en[nd]);
    }

    int s[SZ2];
};



// Capacities: nodes, regions, the size at which a region counts as "heavy",
// and the number of heavy-region slots (rows of prec / columns of prec2).
constexpr int maxn = 200005;
constexpr int maxr = 25005;
constexpr int sqrtn = 450;
constexpr int sqrtr = 160;

// Input: node / region / query counts, child lists, and each node's region.
int n, r, q;
vector<int> adj[maxn];
int re[maxn];

// Euler tour: per-node entry/exit times, the running clock, and per-region
// lists of entry and exit times.
int st[maxn], en[maxn];
int cnt = 0;
vector<int> sts[maxr], ens[maxr];

// s[region]: heavy slot of a region (solve() initializes every entry to -1);
// prec[slot][x] and prec2[x][slot]: precomputed answers for heavy regions;
// cnt2: number of slots handed out so far.
int s[maxr];
int prec[sqrtr][maxr];
int prec2[maxr][sqrtr];
int cnt2 = 0;

// Euler-tour DFS from nd (par = node we came from). st[nd] receives the next
// entry time and en[nd] the clock value right after nd's whole subtree, so a
// node u lies in nd's subtree exactly when st[nd] <= st[u] < en[nd]. Both times
// are appended to the lists of nd's region; since the clock only moves
// forward, sts[x] and ens[x] end up sorted.
void dfs(int nd, int par) {
    const int enter = cnt;
    ++cnt;
    st[nd] = enter;
    sts[re[nd]].push_back(enter);

    for (int child : adj[nd]) {
        if (child == par) continue;
        dfs(child, nd);
    }

    en[nd] = cnt;
    ens[re[nd]].push_back(cnt);
}

int curreg;
int dfs2(int nd, int par, int num) {
    // FOR prec[CURRENT][SOMETHING], update if current is curreg
    if (re[nd] != curreg) prec[s[curreg]][re[nd]] += num;
    num += re[nd] == curreg;

    int num_current = re[nd] == curreg;
    for (auto i : adj[nd])
        if (i != par)
            num_current += dfs2(i, nd, num);

    // FOR prec2[SOMETHING][CURRENT], return number of currents
    if (re[nd] != curreg) prec2[re[nd]][s[curreg]] += num_current;
    return num_current;    
}

// Counts pairs (u, v) with u in region a, v in region b and v inside u's
// subtree (for a != b, u is then a proper ancestor of v). It sweeps region a's
// entry times (sts[a]) and exit times (ens[a]) together with region b's entry
// times (sts[b]), all sorted. `open` is the number of a-nodes whose subtree
// interval contains the sweep position; each b-node adds that count.
// Each loop iteration advances one index, so the cost is O(|sts[a]| + |sts[b]|).
int qry(int a, int b) {
    const auto& aStart = sts[a];
    const auto& aEnd = ens[a];
    const auto& bStart = sts[b];

    int open = 0, num = 0;
    size_t c1 = 0, c2 = 0, c3 = 0;
    while (c2 < bStart.size() && c3 < aEnd.size()) {
        const int nxt = min(min(c1 < aStart.size() ? aStart[c1] : INF, bStart[c2]), aEnd[c3]);
        // Exits go first on ties: en[u] is the first time *outside* u's subtree.
        if (aEnd[c3] == nxt) { open--; c3++; continue; }
        if (c1 < aStart.size() && aStart[c1] == nxt) { open++; c1++; continue; }
        // Only remaining possibility: bStart[c2] == nxt.
        num += open;
        c2++;
    }
    return num;
}

void solve() {
    fill(s, s+maxr, -1); // set regions to -1

    cin >> n >> r >> q;
    cin >> re[0]; re[0]--;
    for (int i = 1; i < n; i++) 
        { int a; cin >> a >> re[i]; a--, re[i]--; adj[a].push_back(i); }

    // precompute
    dfs(0, 0);
    for (int i = 0; i < r; i++)
        if (sts[i].size() >= sqrtn) {
            s[i] = cnt2++; curreg = i;
            dfs2(0, 0, 0);   
        }

    // for (int i = 0; i < n; i++) {
    //     cout << "i " << st[i] << ' ';
    //     for (auto j : adj[i]) cout << st[j] << ' ';
    //     cout << endl;
    // }

    // for (int i = 0; i < 3; i++) {
    //     cout << "STS "; for (auto j : sts[i]) cout << j << ' '; cout << " | ENS "; for (auto j : ens[i]) cout << j << ' ';
    //     cout << endl;
    // }

    for (int i = 0; i < q; i++) {
        int a, b; cin >> a >> b; a--, b--;
        if (s[a] != -1) {cout << prec[s[a]][b] << endl; continue;}
        if (s[b] != -1) {cout << prec2[a][s[b]] << endl; continue;}
        cout << qry(a, b) << endl;
    }
}

// #define LOCAL
// #define CODEFORCES

// Entry point; `signed` is spelled out because `int` is macro-redefined as
// `long long` earlier in the file. LOCAL builds read main.in and print per-case
// banners and timing. CODEFORCES builds read a test-case count first.
// Otherwise exactly one case is solved.
signed main() {
#ifdef LOCAL
    freopen("main.in", "r", stdin);
#else
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
#endif

    int t;
#ifdef CODEFORCES
    cin >> t;
#else
    t = 1;
#endif

    for (int i = 1; i <= t; i++) {
#ifdef LOCAL
        cout << fstr("----- Case {} -----", i) << endl;
        const auto startTime = clock();
        solve();
        cout << fstr("RUNTIME: {}", static_cast<double>(clock() - startTime) / CLOCKS_PER_SEC) << endl;
        cout << fstr("----- END CASE {} -----", i) << endl;
#else
        solve();
#endif
    }
    return 0;
}

Compilation message

regions.cpp: In function 'long long int qry(long long int, long long int)':
regions.cpp:113:15: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  113 |     while (c2 < sts[b].size() && c3 < ens[a].size()) {
      |            ~~~^~~~~~~~~~~~~~~
regions.cpp:113:37: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  113 |     while (c2 < sts[b].size() && c3 < ens[a].size()) {
      |                                  ~~~^~~~~~~~~~~~~~~
regions.cpp:115:31: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  115 |         int nxt = min(min((c1 < sts[a].size() ? sts[a][c1] : INF), sts[b][c2]), ens[a][c3]);
      |                            ~~~^~~~~~~~~~~~~~~
regions.cpp:117:16: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  117 |         if (c1 < sts[a].size() && sts[a][c1] == nxt) { cnt++; c1++; continue; }
      |             ~~~^~~~~~~~~~~~~~~
# 결과 실행 시간 메모리 Grader output
1 Correct 3 ms 12108 KB Output is correct
2 Correct 3 ms 11864 KB Output is correct
3 Correct 5 ms 11864 KB Output is correct
4 Correct 5 ms 12120 KB Output is correct
5 Correct 7 ms 12120 KB Output is correct
6 Correct 10 ms 12120 KB Output is correct
7 Correct 14 ms 12120 KB Output is correct
8 Correct 21 ms 12120 KB Output is correct
9 Correct 27 ms 12632 KB Output is correct
10 Correct 43 ms 12468 KB Output is correct
11 Correct 63 ms 12920 KB Output is correct
12 Correct 70 ms 13516 KB Output is correct
13 Correct 93 ms 13252 KB Output is correct
14 Correct 115 ms 13824 KB Output is correct
15 Correct 135 ms 16936 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 568 ms 23256 KB Output is correct
2 Correct 614 ms 22100 KB Output is correct
3 Correct 1055 ms 25336 KB Output is correct
4 Correct 149 ms 13876 KB Output is correct
5 Correct 202 ms 15520 KB Output is correct
6 Correct 377 ms 37788 KB Output is correct
7 Correct 686 ms 34812 KB Output is correct
8 Correct 679 ms 56844 KB Output is correct
9 Correct 1138 ms 22068 KB Output is correct
10 Correct 1504 ms 70408 KB Output is correct
11 Correct 1754 ms 21624 KB Output is correct
12 Correct 678 ms 45208 KB Output is correct
13 Correct 943 ms 45944 KB Output is correct
14 Correct 1276 ms 55604 KB Output is correct
15 Correct 1472 ms 60396 KB Output is correct
16 Correct 1402 ms 67440 KB Output is correct
17 Correct 1565 ms 66452 KB Output is correct