Submission #840572

# | Time | Username | Problem | Language | Result | Execution time | Memory
840572 | arbuzick | Closing Time (IOI23_closing) | C++17
29 / 100
1090 ms | 105812 KiB
#include <algorithm>
#include <array>
#include <cstddef>
#include <utility>
#include <vector>

using namespace std;

// Capacity of the per-vertex global arrays (N may be as large as 2e5).
constexpr int maxn = 2e5 + 5;
// Cost sentinel meaning "this count cannot be achieved". Assumed to exceed every budget k the
// grader passes (the task bounds K by 1e18 -- confirm against the statement).
constexpr long long inf = static_cast<long long>(1e18) + 7;
// The exact tree knapsack is only run for n <= maxn2; larger trees keep the greedy answer.
constexpr int maxn2 = 3e3 + 5;

vector<pair<int, int>> g[maxn];        // adjacency list: (neighbour, edge weight)
long long dist_x[maxn], dist_y[maxn];  // weighted distances from x and from y
int prv[maxn];                         // parent pointers left by the latest traversal
vector<int> path;                      // vertices of the x -> y path, from x to y
bool used[maxn];                       // true for the vertices stored in `path`

// Knapsack table of a subtree: table[cnt][mask] is the minimum total closing time spent in the
// subtree to obtain `cnt` (vertex, site) reachability pairs there. After finalize_table(),
// `mask` lists the sites each vertex is allowed to be reached from (bit 0 = x, bit 1 = y).
using Table = vector<array<long long, 4>>;

// Stores in dist[] the weighted distance of every vertex from `root`, and in prv[] its parent
// towards `root` (prv[root] == root). Iterative, so long chains cannot overflow the call stack.
void calc_from(int root, long long* dist) {
    prv[root] = root;
    dist[root] = 0;
    vector<int> todo = {root};
    while (!todo.empty()) {
        const int v = todo.back();
        todo.pop_back();
        for (const auto& edge : g[v]) {
            const int u = edge.first;
            if (u == prv[v]) {
                continue;
            }
            prv[u] = v;
            dist[u] = dist[v] + edge.second;
            todo.push_back(u);
        }
    }
}

// Fills dist_x / dist_y and appends the x -> y path (x first, y last) to `path`.
void calc_dist(int x, int y) {
    calc_from(x, dist_x);
    calc_from(y, dist_y);
    // prv[] is now rooted at y, so following it from x walks the x -> y path.
    for (int cur = x; cur != y; cur = prv[cur]) {
        path.push_back(cur);
    }
    path.push_back(y);
}

// Min-plus convolution of two tables, done independently for each mask. Unreachable (>= inf)
// entries are skipped so sentinel values are never added together.
Table merge_tables(const Table& a, const Table& b) {
    Table out(a.size() + b.size() - 1);
    for (auto& row : out) {
        row.fill(inf);
    }
    for (size_t i = 0; i < a.size(); ++i) {
        for (int mask = 0; mask < 4; ++mask) {
            if (a[i][mask] >= inf) {
                continue;
            }
            for (size_t j = 0; j < b.size(); ++j) {
                if (b[j][mask] < inf) {
                    out[i + j][mask] = min(out[i + j][mask], a[i][mask] + b[j][mask]);
                }
            }
        }
    }
    return out;
}

// Turns "exact state of the root" entries into "root reached by a subset of mask" entries:
// the x-only and y-only slots may also leave the subtree unreached, and the both-sites slot
// may use either single-site option.
void finalize_table(Table& t) {
    for (auto& row : t) {
        row[1] = min(row[1], row[0]);
        row[2] = min(row[2], row[0]);
        row[3] = min(row[3], min(row[1], row[2]));
    }
}

// Finalized table of the off-path subtree rooted at v, entered from `parent`.
Table subtree_table(int v, int parent) {
    Table t(3);
    for (auto& row : t) {
        row.fill(inf);
    }
    t[0][0] = 0;                          // v reached by neither site
    t[1][1] = dist_x[v];                  // v reached from x only
    t[1][2] = dist_y[v];                  // v reached from y only
    t[2][3] = max(dist_x[v], dist_y[v]);  // v reached from both sites
    for (const auto& edge : g[v]) {
        const int u = edge.first;
        if (!used[u] && u != parent) {
            t = merge_tables(t, subtree_table(u, v));
        }
    }
    finalize_table(t);
    return t;
}

// Finalized table of everything hanging off path vertex p; p's own cost is NOT included.
Table off_path_table(int p) {
    Table t(1);
    t[0].fill(0);
    for (const auto& edge : g[p]) {
        const int u = edge.first;
        if (!used[u]) {
            t = merge_tables(t, subtree_table(u, p));
        }
    }
    finalize_table(t);
    return t;
}

// Returns the largest number of (vertex, site) reachability pairs obtainable with total
// closing time at most k, where reaching a vertex from x costs dist_x, from y costs dist_y,
// and from both costs the max of the two.
//
// Lower bound (any n): buy the cheapest individual distances greedily (relies on positive
// edge weights so the bought set stays connected -- assumed from the task constraints).
// Exact search (n <= maxn2): walking the x -> y path, vertices go through phases
// 0 (x only) -> 1 (both) -> 2 (y only) with at least one vertex in phase 1. A DP over
// (phase, count), combined with each path vertex's off-path table, gives the minimum cost of
// every count in O(n^2) overall.
int max_score(int n, int x, int y, long long k, vector<int> u, vector<int> v, vector<int> w) {
    path.clear();
    for (int i = 0; i < n; ++i) {
        g[i].clear();
        dist_x[i] = dist_y[i] = 0;
        used[i] = false;
    }
    for (int i = 0; i < n - 1; ++i) {
        g[u[i]].emplace_back(v[i], w[i]);
        g[v[i]].emplace_back(u[i], w[i]);
    }
    calc_dist(x, y);

    // Greedy: take the cheapest single-site distances while the budget allows.
    vector<long long> costs;
    costs.reserve(2 * static_cast<size_t>(n));
    for (int i = 0; i < n; ++i) {
        costs.push_back(dist_x[i]);
        costs.push_back(dist_y[i]);
    }
    sort(costs.begin(), costs.end());
    int ans = 0;
    long long budget = k;
    for (const long long c : costs) {
        if (c > budget) {
            break;
        }
        budget -= c;
        ++ans;
    }
    if (n > maxn2) {
        return ans;
    }

    for (const int p : path) {
        used[p] = true;
    }

    // Phase q of a path vertex: 0 = x only, 1 = both sites, 2 = y only.
    constexpr int kMask[3] = {1, 3, 2};      // table column allowed for its off-path subtree
    constexpr int kOwnCount[3] = {1, 2, 1};  // pairs contributed by the path vertex itself
    // best[q][cnt]: minimum cost of the processed path prefix whose last vertex is in phase q.
    array<vector<long long>, 3> best;
    for (size_t i = 0; i < path.size(); ++i) {
        const int p = path[i];
        const Table sub = off_path_table(p);
        const long long own_cost[3] = {dist_x[p], max(dist_x[p], dist_y[p]), dist_y[p]};

        // from[q]: cheapest prefix that may legally be followed by a vertex in phase q.
        array<vector<long long>, 3> from;
        if (i == 0) {
            from[0] = {0};    // the path may start in phase 0 ...
            from[1] = {0};    // ... or directly in phase 1,
            from[2] = {inf};  // but never in phase 2.
        } else {
            const size_t len = best[0].size();
            for (auto& row : from) {
                row.resize(len);
            }
            for (size_t c = 0; c < len; ++c) {
                from[0][c] = best[0][c];
                from[1][c] = min(best[0][c], best[1][c]);
                from[2][c] = min(best[1][c], best[2][c]);
            }
        }

        // Counts grow by at most (sub.size() - 1) from the subtree plus 2 from p itself.
        const size_t new_len = from[0].size() + sub.size() + 1;
        for (int q = 0; q < 3; ++q) {
            vector<long long> nxt(new_len, inf);
            for (size_t a = 0; a < from[q].size(); ++a) {
                if (from[q][a] >= inf) {
                    continue;
                }
                for (size_t b = 0; b < sub.size(); ++b) {
                    const long long extra = sub[b][kMask[q]];
                    if (extra >= inf) {
                        continue;
                    }
                    long long& slot = nxt[a + b + kOwnCount[q]];
                    slot = min(slot, from[q][a] + extra + own_cost[q]);
                }
            }
            best[q] = std::move(nxt);
        }
    }
    // The path must end in phase 1 or 2 (at least one vertex reached from both sites).
    for (size_t c = 0; c < best[1].size(); ++c) {
        if (min(best[1][c], best[2][c]) <= k) {
            ans = max(ans, static_cast<int>(c));
        }
    }
    return ans;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...