Submission #535455

# | Time | Username | Problem | Language | Result | Execution time | Memory
535455 | — | chenwz | Regions (IOI09_regions) | C++11 | 100 / 100 | 2273 ms | 46736 KiB
// IOI 2009 – Regions
//
// Employees form a rooted tree (employee 1 is the root) and every employee
// belongs to one region.  A query (r1, r2) asks for the number of pairs
// (e1, e2) with e1 in region r1, e2 in region r2 and e1 a proper ancestor
// of e2.
//
// Approach: a pre-order walk gives every employee an index.  For each region
// we keep (a) the sorted pre-order indices of its members and (b) a sorted
// list of "breakpoints" telling, for any pre-order index, how many members of
// that region are proper ancestors of the node with that index.  A query is
// answered by combining (b) of r1 with (a) of r2 using whichever of three
// strategies is estimated to be cheapest; answers are memoised.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <map>
#include <utility>
#include <vector>

using VI = std::vector<int>;
using IP = std::pair<int, int>;
using LL = long long;

// One employee (tree node).
struct Emp {
  int tin = 0;  // DFS timestamp slot; never assigned or read by this program
  int reg = 0;  // 0-based region index
  VI ch;        // 0-based indices of the directly supervised employees
};

// Per-region data collected by dfs().
struct Reg {
  VI ids;  // pre-order indices of the members, in increasing order
  // Breakpoints {pos, k}, non-decreasing in pos: a node whose pre-order index
  // lies in [pos, next breakpoint's pos) has k proper ancestors in this region.
  std::vector<IP> ranges;
  int cnt = 0;  // members currently open on the DFS path (scratch for dfs)
};

std::vector<Emp> ES;
std::vector<Reg> RS;

// Pre-order walk of the subtree rooted at `u`.
//
// `timer` is the next free pre-order index; it is advanced by the size of the
// subtree.  For every visited node the walk appends its index to its region's
// `ids`, a breakpoint {index + 1, ++cnt} on entry and a breakpoint
// {timer, --cnt} once the whole subtree has been numbered.  The entry
// breakpoint starts at index + 1 so that a node is never counted as its own
// ancestor.
//
// An explicit stack is used instead of recursion so that the depth of the
// hierarchy is not limited by the call stack.
void dfs(int u, int &timer) {
  std::vector<std::pair<int, std::size_t>> path;  // {node, next child slot}

  const auto enter = [&](int node) {
    Reg &region = RS[ES[node].reg];
    region.ids.push_back(timer++);
    region.ranges.push_back({timer, ++region.cnt});
    path.emplace_back(node, 0);
  };

  enter(u);
  while (!path.empty()) {
    const int node = path.back().first;
    const std::size_t slot = path.back().second;
    if (slot < ES[node].ch.size()) {
      ++path.back().second;  // advance before `enter` may reallocate `path`
      enter(ES[node].ch[slot]);
    } else {
      // Subtree finished: from `timer` on, this node is no longer an ancestor.
      Reg &region = RS[ES[node].reg];
      region.ranges.push_back({timer, --region.cnt});
      path.pop_back();
    }
  }
}

// For every member of r2, binary-search the last breakpoint of r1 at or before
// its index and add that breakpoint's count.  O(|r2| log |r1|).
LL query_by_id(const Reg &r1, const Reg &r2) {
  const std::vector<IP> &breaks = r1.ranges;
  LL ans = 0;
  for (const int id : r2.ids) {
    // First breakpoint that starts strictly after `id`; its predecessor, if
    // any, is the one covering `id`.
    const auto after = std::upper_bound(
        breaks.begin(), breaks.end(), id,
        [](int value, const IP &bp) { return value < bp.first; });
    if (after != breaks.begin()) ans += std::prev(after)->second;
  }
  return ans;
}

// For every pair of consecutive breakpoints of r1, binary-search how many
// members of r2 fall between them and weight that by the breakpoint's count.
// O(|r1| log |r2|).
LL query_by_range(const Reg &r1, const Reg &r2) {
  const std::vector<IP> &breaks = r1.ranges;
  const VI &ids = r2.ids;
  LL ans = 0;
  if (breaks.empty()) return ans;

  // The upper bound of one interval is the lower bound of the next, so each
  // boundary is searched only once.
  auto lo = std::lower_bound(ids.begin(), ids.end(), breaks.front().first);
  for (std::size_t i = 0; i + 1 < breaks.size(); ++i) {
    const auto hi = std::lower_bound(lo, ids.end(), breaks[i + 1].first);
    ans += static_cast<LL>(breaks[i].second) * (hi - lo);
    lo = hi;
  }
  return ans;
}

// Merge-style linear sweep over r1's breakpoints and r2's members.
// O(|r1| + |r2|).
LL query_stitch(const Reg &r1, const Reg &r2) {
  LL ans = 0;
  const std::vector<IP> &breaks = r1.ranges;
  if (breaks.empty()) return ans;

  auto cur = r2.ids.begin();
  const auto last = r2.ids.end();

  // Members that come before the first breakpoint have no ancestor in r1.
  while (cur != last && *cur < breaks.front().first) ++cur;

  for (std::size_t i = 0; i + 1 < breaks.size() && cur != last; ++i) {
    auto next = cur;
    while (next != last && *next < breaks[i + 1].first) ++next;
    ans += static_cast<LL>(breaks[i].second) * (next - cur);
    cur = next;
  }
  return ans;
}

// Answers the query for 0-based regions (r1, r2), memoising the result.
// Picks the strategy with the smallest estimated cost.
LL solve(int r1, int r2) {
  const Reg &reg1 = RS[r1];
  const Reg &reg2 = RS[r2];
  const int sz1 = static_cast<int>(reg1.ids.size());
  const int sz2 = static_cast<int>(reg2.ids.size());

  // A region without members contributes no pairs.  Returning here also keeps
  // log2() below away from 0, whose result (-inf) cannot be converted to int.
  if (sz1 == 0 || sz2 == 0) return 0;

  static std::map<IP, LL> cache;
  const IP key(r1, r2);
  const auto hit = cache.find(key);
  if (hit != cache.end()) return hit->second;

  // Rough cost model; the factor 5 penalises the binary-search variants
  // relative to the linear sweep.
  const int costs[3] = {
      sz1 * (static_cast<int>(std::log2(sz2)) + 2) * 5,  // query_by_range
      sz2 * (static_cast<int>(std::log2(sz1)) + 2) * 5,  // query_by_id
      sz1 + sz2,                                         // query_stitch
  };
  const int best =
      static_cast<int>(std::min_element(costs, costs + 3) - costs);

  LL ans;
  if (best == 0) {
    ans = query_by_range(reg1, reg2);
  } else if (best == 1) {
    ans = query_by_id(reg1, reg2);
  } else {
    ans = query_stitch(reg1, reg2);
  }
  cache[key] = ans;
  return ans;
}

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  int N, R, Q;
  if (!(std::cin >> N >> R >> Q) || N < 1 || R < 1) return 0;
  ES.resize(N);
  RS.resize(R);

  // Employee 1 is the root and has no supervisor on its input line.
  std::cin >> ES[0].reg;
  --ES[0].reg;
  for (int i = 1; i < N; ++i) {
    int fa;
    std::cin >> fa >> ES[i].reg;
    --ES[i].reg;
    ES[fa - 1].ch.push_back(i);
  }

  int timer = 0;
  dfs(0, timer);

  for (int q = 0; q < Q; ++q) {
    int r1, r2;
    std::cin >> r1 >> r2;
    // NOTE(review): the flush after every answer is kept deliberately — the
    // task is believed to be interactive (each query is sent only after the
    // previous answer was flushed).  Confirm before replacing endl with '\n'.
    std::cout << solve(r1 - 1, r2 - 1) << std::endl;
  }
  return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...