Submission #853777

#TimeUsernameProblemLanguageResultExecution timeMemory
853777FairyWinxClosing Time (IOI23_closing)C++17
52 / 100
150 ms40172 KiB
#include <bits/stdc++.h> #define all(a) a.begin(), a.end() #define rall(a) a.rbegin(), a.rend() using namespace std; vector<vector<pair<int, int>>> G; vector<int> ban; void dfs_calc_dist(int v, int par, long long d, vector<long long> &dist) { dist[v] = d; for (auto i : G[v]) { if (i.first != par) { dfs_calc_dist(i.first, v, d + i.second, dist); } } } void dfs_find_par(int v, int p, vector<int> &par) { par[v] = p; for (auto i : G[v]) { if (i.first != p) { dfs_find_par(i.first, v, par); } } } void init(int n) { G.clear(); ban.clear();; G.resize(n); ban.resize(n); } pair<set<pair<long long, int>>, set<pair<long long, int>>> build(long long k, vector<pair<long long, int>> &value1, vector<pair<long long, int>> &value2, vector<long long> &dist1, vector<long long> &dist2) { sort(rall(value1)); sort(rall(value2)); vector<pair<int, int>> ans; set<pair<long long, int>> used1, used2; while (value1.size() || value2.size()) { if (!value1.size()) { if (value2.back().first > k) { break; } else { int v = value2.back().second; used2.emplace(max(dist1[v], dist2[v]), v); k -= value2.back().first; value2.pop_back(); } } else if (!value2.size()) { if (value1.back().first > k) { break; } else { int v = value1.back().second; used1.emplace(abs(dist1[v] - dist2[v]), v); k -= value1.back().first; value1.pop_back(); } } else { if (k >= value1.back().first && (k < value2.back().first || (2 * value1.back().first < value2.back().first))) { // value1 оптимальнее int v = value1.back().second; used1.emplace(abs(dist1[v] - dist2[v]), v); k -= value1.back().first; value1.pop_back(); } else if (k >= value2.back().first) { int v = value2.back().second; used2.emplace(max(dist1[v], dist2[v]), v); k -= value2.back().first; value2.pop_back(); } else { break; } } } return {used1, used2}; } void update_ans(long long k, long long &cur_sum, set<pair<long long, int>> &used1, set<pair<long long, int>> &used2, set<pair<long long, int>> &unused1, set<pair<long long, int>> &unused2) { // aboba while (cur_sum > k) { if (!used1.size()) { cur_sum -= used2.rbegin()->first; unused2.emplace(*used2.rbegin()); used2.erase(*used2.rbegin()); } else if (!used2.size()) { cur_sum -= used1.rbegin()->first; unused1.emplace(*used1.rbegin()); used1.erase(*used1.rbegin()); } else { if (used1.rbegin()->first * 2 < used2.rbegin()->first) { cur_sum -= used2.rbegin()->first; unused2.emplace(*used2.rbegin()); used2.erase(*used2.rbegin()); } else { cur_sum -= used1.rbegin()->first; unused1.emplace(*used1.rbegin()); used1.erase(*used1.rbegin()); } } } while (unused1.size() || unused2.size()) { if (!unused1.size() || k - cur_sum < unused1.begin()->first) { if (unused2.size() && k - cur_sum >= unused2.begin()->first) { cur_sum += unused2.begin()->first; used2.insert(*unused2.begin()); unused2.erase(unused2.begin()); continue; } else { break; } } if (!unused2.size() || k - cur_sum < unused2.begin()->first) { if (unused1.size() && k - cur_sum >= unused1.begin()->first) { cur_sum += unused1.begin()->first; used1.insert(*unused1.begin()); unused1.erase(unused1.begin()); continue; } else { break; } } if (unused1.begin()->first * 2 < unused2.begin()->first) { cur_sum += unused1.begin()->first; used1.insert(*unused1.begin()); unused1.erase(unused1.begin()); } else { cur_sum += unused2.begin()->first; used2.insert(*unused2.begin()); unused2.erase(unused2.begin()); } } } int get_ans(long long k, long long cur_sum, set<pair<long long, int>> &used1, set<pair<long long, int>> &used2, set<pair<long long, int>> &unused2) { if (used1.size() && unused2.size()) { if (k - cur_sum + used1.rbegin()->first >= unused2.begin()->first) { return used1.size() + used2.size() * 2 + 1; } } return used1.size() + used2.size() * 2; } int max_score(int n, int x, int y, long long k, vector<int> u, vector<int> v, vector<int> w) { init(n); for (int i = 0; i < n - 1; ++i) { G[u[i]].emplace_back(v[i], w[i]); G[v[i]].emplace_back(u[i], w[i]); } vector<long long> dist1(n), dist2(n); dfs_calc_dist(x, -1, 0, dist1); dfs_calc_dist(y, -1, 0, dist2); vector<pair<long long, int>> value; for (int i = 0; i < n; ++i) { value.emplace_back(dist1[i], i); value.emplace_back(dist2[i], i); } sort(all(value)); int ans = 0; { vector<int> used(n); long long sum = 0; int cnt = 0; for (auto c : value) { if (used[c.second]) continue; sum += c.first; ++cnt; used[c.second] = 1; if (sum > k) break; ans = max(ans, cnt); } } vector<int> path; { vector<int> par(n); dfs_find_par(x, -1, par); int tmp = y; while (tmp != -1) { path.push_back(tmp); tmp = par[tmp]; } } vector<int> used(n); int cnt = 0; for (int v : path) { k -= min(dist1[v], dist2[v]); ++cnt; used[v] = 1; } if (k < 0) return ans; vector<pair<long long, int>> value1, value2; for (int i = 0; i < n; ++i) { if (used[i]) { value1.emplace_back(abs(dist1[i] - dist2[i]), i); } else { value2.emplace_back(max(dist1[i], dist2[i]), i); } } auto [used1, used2] = build(k, value1, value2, dist1, dist2); set<pair<long long, int>> unused1; set<pair<long long, int>> unused2; { vector<int> in_ans(n); for (auto i : used2) in_ans[i.second] = 1; for (int i = 0; i < n; ++i) { if (!in_ans[i] && !used[i]) { unused2.emplace(max(dist1[i], dist2[i]), i); } } } { vector<int> in_ans(n); for (auto i : used1) in_ans[i.second] = 1; for (int i = 0; i < n; ++i) { if (!in_ans[i] && used[i]) { unused1.emplace(abs(dist1[i] - dist2[i]), i); } } } long long cur_sum = 0; for (auto i : used1) cur_sum += i.first; for (auto i : used2) cur_sum += i.first; ans = max(ans, cnt + get_ans(k, cur_sum, used1, used2, unused2)); for (auto c : value) { if (used[c.second]) continue; used[c.second] = 1; ++cnt; k -= c.first; if (k < 0) break; if (used2.count(make_pair(max(dist1[c.second], dist2[c.second]), c.second))) cur_sum -= max(dist1[c.second], dist2[c.second]); used2.erase(make_pair(max(dist1[c.second], dist2[c.second]), c.second)); unused2.erase(make_pair(max(dist1[c.second], dist2[c.second]), c.second)); unused1.emplace(abs(dist1[c.second] - dist2[c.second]), c.second); update_ans(k, cur_sum, used1, used2, unused1, unused2); ans = max(ans, cnt + get_ans(k, cur_sum, used1, used2, unused2)); } return ans; } #ifdef LOCAL int main() { // BEGIN SECRET { std::string in_secret = "cc61ad56a4797fb3f5c9529f73ce6fcedd85669b"; std::string out_secret = "081ce3c351cbf526b37954b9ad30f2b531a7585c"; char secret[1000]; assert(1 == scanf("%s", secret)); if (std::string(secret) != in_secret) { printf("%s\n", out_secret.c_str()); printf("SV\n"); fclose(stdout); return 0; } } // END SECRET int Q; assert(1 == scanf("%d", &Q)); std::vector<int> N(Q), X(Q), Y(Q); std::vector<long long> K(Q); std::vector<std::vector<int>> U(Q), V(Q), W(Q); for (int q = 0; q < Q; q++) { assert(4 == scanf("%d %d %d %lld", &N[q], &X[q], &Y[q], &K[q])); U[q].resize(N[q] - 1); V[q].resize(N[q] - 1); W[q].resize(N[q] - 1); for (int i = 0; i < N[q] - 1; ++i) { assert(3 == scanf("%d %d %d", &U[q][i], &V[q][i], &W[q][i])); } } fclose(stdin); std::vector<int> result(Q); for (int q = 0; q < Q; q++) { result[q] = max_score(N[q], X[q], Y[q], K[q], U[q], V[q], W[q]); } // BEGIN SECRET { std::string out_secret = "081ce3c351cbf526b37954b9ad30f2b531a7585c"; printf("%s\n", out_secret.c_str()); printf("OK\n"); } // END SECRET for (int q = 0; q < Q; q++) { printf("%d\n", result[q]); } fclose(stdout); return 0; } #endif
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...