Submission #1105572

| # | Username | Problem | Language | Result | Execution time | Memory |
|---|---|---|---|---|---|---|
| 1105572 | Zicrus | Nile (IOI24_nile) | C++17 | 100 / 100 | 116 ms | 25136 KiB |
// Solution for "Nile" (IOI24_nile).
//
// Approach as implemented below: artifacts are sorted by weight W and queries
// by threshold D, and queries are answered offline in increasing D.
//
// The running total starts at sum(A). Every artifact that gets paired adds its
// (B - A) instead. A disjoint-set union tracks contiguous runs of sorted
// artifacts whose adjacent weight gaps are <= D:
//   * An even-sized run is fully paired.
//   * An odd-sized run leaves exactly one member out. The member left out is
//     the one with the largest (B - A) among
//       - members at an even offset from the run's left end, and
//       - members at an odd offset whose two neighbours are within D of each
//         other (a "jump" gap), so those neighbours can pair around it.
#include <bits/stdc++.h>
#if __has_include("nile.h")
#include "nile.h"  // task header; guarded so the file also builds standalone
#endif
using namespace std;
using ll = long long;

// Sentinel meaning "no candidate" in the running maxima below.
const ll NEG_INF = -(1ll << 62);

ll n, q, sumA, sumRun;
vector<pair<ll, ll>> art;  // artifacts sorted by weight: {W, B - A}
vector<pair<ll, ll>> d;    // queries sorted by threshold: {D, original index}

// Union-find over sorted artifact indices.
//
// Callers only unite index i with i + 1, so every set is a contiguous run.
// unite() always links under the smaller root, so a run's root is its
// leftmost index.
//
// Fields kept per root:
//   sz        - run length
//   sumDiff   - sum of (B - A) over the run
//   mxEven    - max (B - A) at an even offset from the root
//   mxOdd     - max (B - A) at an odd offset from the root
//   mxJmpEven - max (B - A) at an even offset whose neighbours are jump-linked
//   mxJmpOdd  - max (B - A) at an odd offset whose neighbours are jump-linked
//
// mxOdd and mxJmpEven are only used when a merge flips offsets.
vector<ll> lnk, sz, sumDiff, mxJmpEven, mxJmpOdd, mxEven, mxOdd;

// Returns the root of a's set. Path compression is iterative to avoid deep
// recursion on long link chains.
ll find(ll a) {
    ll root = a;
    while (lnk[root] != root) root = lnk[root];
    while (lnk[a] != root) {  // second pass: point every node on the path at root
        ll next = lnk[a];
        lnk[a] = root;
        a = next;
    }
    return root;
}

// Merges the runs containing a and b under the smaller (leftmost) root.
// The right run's offsets shift by the left run's size. When that size is odd,
// the right run's even/odd maxima swap roles.
void unite(ll a, ll b) {
    a = find(a);
    b = find(b);
    if (a == b) return;
    if (a < b) swap(a, b);  // b: left run's root (kept), a: right run's root
    lnk[a] = b;
    if (sz[b] & 1) {  // odd left run: right run's parities flip
        mxJmpEven[b] = max(mxJmpEven[b], mxJmpOdd[a]);
        mxJmpOdd[b] = max(mxJmpOdd[b], mxJmpEven[a]);
        mxEven[b] = max(mxEven[b], mxOdd[a]);
        mxOdd[b] = max(mxOdd[b], mxEven[a]);
    } else {          // even left run: parities are preserved
        mxJmpEven[b] = max(mxJmpEven[b], mxJmpEven[a]);
        mxJmpOdd[b] = max(mxJmpOdd[b], mxJmpOdd[a]);
        mxEven[b] = max(mxEven[b], mxEven[a]);
        mxOdd[b] = max(mxOdd[b], mxOdd[a]);
    }
    sz[b] += sz[a];
    sumDiff[b] += sumDiff[a];
}

// Amount that a's run adds on top of sumA. Every member adds its (B - A).
// An odd-sized run then takes back the largest (B - A) among members that may
// stay unpaired: any even offset, or a jump-enabled odd offset.
ll getContrib(ll a) {
    a = find(a);
    ll res = sumDiff[a];
    if (sz[a] & 1) res -= max(mxEven[a], mxJmpOdd[a]);
    return res;
}

// For each query j, returns sum(A) plus the (B - A) adjustments of all runs
// formed by weight gaps <= E[j]; res[j] corresponds to E[j].
// All state (globals and the gap queue) is rebuilt on every call, so repeated
// calls are independent of each other.
vector<ll> calculate_costs(vector<int> W, vector<int> A, vector<int> B, vector<int> E) {
    n = W.size();
    q = E.size();
    vector<ll> res(q);

    // Artifacts sorted by weight, keeping only the pairing adjustment B - A.
    art = vector<pair<ll, ll>>(n);
    sumA = 0;
    for (int i = 0; i < n; i++) {
        art[i] = {W[i], static_cast<ll>(B[i]) - A[i]};
        sumA += A[i];
    }
    d = vector<pair<ll, ll>>(q);
    for (int i = 0; i < q; i++) d[i] = {E[i], i};
    sort(art.begin(), art.end());
    sort(d.begin(), d.end());
    sumRun = sumA;

    // Gap events as { {-distance, isJump}, {left id, right id} }.
    // The max-heap therefore yields the smallest distance first; on ties,
    // jump events come first. The queue is local so no events survive into a
    // later call.
    priority_queue<pair<pair<ll, bool>, pair<ll, ll>>> gaps;
    for (int i = 0; i + 1 < n; i++) {
        gaps.push({{-(art[i + 1].first - art[i].first), false}, {i, i + 1}});
    }
    for (int i = 0; i + 2 < n; i++) {
        gaps.push({{-(art[i + 2].first - art[i].first), true}, {i, i + 2}});
    }

    // Union-find setup: every artifact starts as its own run of size 1.
    lnk = vector<ll>(n);
    iota(lnk.begin(), lnk.end(), 0LL);
    sumDiff = vector<ll>(n);
    mxEven = vector<ll>(n);
    for (int i = 0; i < n; i++) {
        sumDiff[i] = art[i].second;
        mxEven[i] = art[i].second;
    }
    sz = vector<ll>(n, 1);
    mxJmpEven = vector<ll>(n, NEG_INF);
    mxJmpOdd = vector<ll>(n, NEG_INF);
    mxOdd = vector<ll>(n, NEG_INF);

    // Answer queries in increasing D, applying every gap event with distance <= D.
    for (int i = 0; i < q; i++) {
        ll D = d[i].first;
        while (!gaps.empty() && -gaps.top().first.first <= D) {
            bool jmp = gaps.top().first.second;
            pair<ll, ll> ids = gaps.top().second;
            gaps.pop();
            if (jmp) {
                // ids = (k, k + 2): the middle artifact k + 1 becomes skippable.
                // Both adjacent gaps are no longer than this one, so by the time
                // the current answer is recorded, k..k+2 share a run.
                ll id = (ids.first + ids.second) / 2;
                ll comp = find(id);
                ll relPos = id - comp;
                sumRun -= getContrib(comp);
                if (relPos & 1) {
                    mxJmpOdd[comp] = max(mxJmpOdd[comp], art[id].second);
                } else {
                    mxJmpEven[comp] = max(mxJmpEven[comp], art[id].second);
                }
                sumRun += getContrib(comp);
            } else {
                ll left = find(ids.first);
                ll right = find(ids.second);
                if (left != right) {
                    sumRun -= getContrib(left) + getContrib(right);
                    unite(left, right);
                    sumRun += getContrib(left);
                }
            }
        }
        res[d[i].second] = sumRun;
    }
    return res;
}

#ifdef TEST
#include "grader.cpp"
#endif

Compilation message (stderr)

nile.cpp: In function 'std::vector<long long int> calculate_costs(std::vector<int>, std::vector<int>, std::vector<int>, std::vector<int>)':
nile.cpp:90:16: warning: unused variable 'dist' [-Wunused-variable]
   90 |             ll dist = -gaps.top().first.first;
      |                ^~~~
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...