Submission #1089563

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|---|---|---|---|---|---|---|
| 1089563 | | VMaksimoski008 | Two Currencies (JOI23_currencies) | C++17 | 100 / 100 | 2775 ms | 238088 KiB |
// JOI 2023 "Two Currencies" (JOI23_currencies).
//
// Approach: sort all checkpoints by silver cost and build a persistent
// segment tree over the Euler tour (entry/exit times), one version per
// prefix of the sorted list. A checkpoint on tree edge (parent, c) adds
// (+cost, +1) at in[c] and (-cost, -1) at out[c], so a prefix aggregate up
// to in[v] covers exactly the checkpoints on the root->v path. For each
// query we binary-search the largest prefix k of the globally sorted
// checkpoints whose on-path silver cost fits into S; every other on-path
// checkpoint has to be paid with one gold coin.
#include <algorithm>
#include <array>
#include <deque>
#include <iostream>
#include <utility>
#include <vector>

using ll = long long;
using pll = std::pair<ll, ll>;

constexpr int LOG = 20;        // binary-lifting levels; 2^LOG exceeds any depth < maxn
constexpr int maxn = 100005;   // array capacity; assumes vertex ids are 1..N with N < maxn

// One persistent segment-tree node: summed silver cost and summed checkpoint
// count of its range (both can be negative locally, because exit positions
// carry negated contributions).
struct Node {
    ll sum = 0;
    ll cnt = 0;
    Node* l = nullptr;
    Node* r = nullptr;
};

// Owns every node for the lifetime of the program. std::deque never relocates
// existing elements on push_back, so the raw Node* links stay valid.
std::deque<Node> pool;

// Allocates a leaf holding the given aggregates.
Node* make_leaf(ll sum, ll cnt) {
    pool.push_back(Node{sum, cnt, nullptr, nullptr});
    return &pool.back();
}

// Allocates an inner node whose aggregates are the sum of its two children.
Node* make_inner(Node* l, Node* r) {
    pool.push_back(Node{l->sum + r->sum, l->cnt + r->cnt, l, r});
    return &pool.back();
}

// Builds the all-zero version over positions [tl, tr].
Node* build(int tl, int tr) {
    if (tl == tr) return make_leaf(0, 0);
    const int tm = (tl + tr) / 2;
    Node* left = build(tl, tm);
    Node* right = build(tm + 1, tr);
    return make_inner(left, right);
}

// Returns a new version equal to `u` except that position p gains
// (dsum, dcnt). `u` itself is left untouched (path copying).
Node* update(Node* u, int tl, int tr, int p, ll dsum, ll dcnt) {
    if (tl == tr) return make_leaf(u->sum + dsum, u->cnt + dcnt);
    const int tm = (tl + tr) / 2;
    if (p <= tm) return make_inner(update(u->l, tl, tm, p, dsum, dcnt), u->r);
    return make_inner(u->l, update(u->r, tm + 1, tr, p, dsum, dcnt));
}

// Aggregates (sum, cnt) over positions [tl, r] of version `u`; requires
// tl <= r <= tr. Walks a single root-to-leaf path.
pll prefix(const Node* u, int tl, int tr, int r) {
    ll sum = 0;
    ll cnt = 0;
    while (tl != tr) {
        const int tm = (tl + tr) / 2;
        if (r <= tm) {
            u = u->l;
            tr = tm;
        } else {
            // The whole left half lies inside [.., r].
            sum += u->l->sum;
            cnt += u->l->cnt;
            u = u->r;
            tl = tm + 1;
        }
    }
    return {sum + u->sum, cnt + u->cnt};
}

std::vector<int> graph[maxn];
std::vector<int> in(maxn), out(maxn), d(maxn);  // entry time, exit time, depth
int timer = 0;
int up[maxn][LOG];  // up[v][j] = 2^j-th ancestor of v (0 past the root)

// Assigns Euler-tour entry/exit times and depths, and fills the binary-lifting
// table. up[u][0] must already be set by the caller (0 for the root).
void dfs_init(int u, int p) {
    in[u] = timer++;
    for (int i = 1; i < LOG; i++) up[u][i] = up[up[u][i - 1]][i - 1];
    for (int v : graph[u]) {
        if (v == p) continue;
        d[v] = d[u] + 1;
        up[v][0] = u;
        dfs_init(v, u);
    }
    out[u] = timer++;
}

// Lowest common ancestor via binary lifting.
int get_lca(int a, int b) {
    if (d[a] < d[b]) std::swap(a, b);
    const int diff = d[a] - d[b];
    for (int j = LOG - 1; j >= 0; j--)
        if ((diff >> j) & 1) a = up[a][j];
    if (a == b) return a;
    for (int j = LOG - 1; j >= 0; j--) {
        if (up[a][j] != up[b][j]) {
            a = up[a][j];
            b = up[b][j];
        }
    }
    return up[a][0];
}

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n, m, q;
    std::cin >> n >> m >> q;

    std::vector<std::array<int, 2>> edges;
    for (int i = 0; i + 1 < n; i++) {
        int a, b;
        std::cin >> a >> b;
        graph[a].push_back(b);
        graph[b].push_back(a);
        edges.push_back({a, b});
    }
    dfs_init(1, 0);

    // Each checkpoint is attributed to the deeper endpoint (child) of its edge.
    std::vector<std::pair<ll, int>> checkpoints;  // (silver cost, child vertex)
    checkpoints.reserve(m);
    for (int i = 0; i < m; i++) {
        int id;
        ll x;
        std::cin >> id >> x;
        const auto& [ea, eb] = edges[id - 1];
        checkpoints.push_back({x, d[ea] > d[eb] ? ea : eb});
    }
    std::sort(checkpoints.begin(), checkpoints.end());

    const int last = 2 * n - 1;  // Euler positions are 0 .. 2n-1
    // roots[2k] is the version containing the k cheapest checkpoints.
    std::vector<Node*> roots;
    roots.reserve(2 * static_cast<std::size_t>(m) + 1);
    roots.push_back(build(0, last));
    for (const auto& [cost, c] : checkpoints) {
        roots.push_back(update(roots.back(), 0, last, in[c], cost, 1));
        roots.push_back(update(roots.back(), 0, last, out[c], -cost, -1));
    }

    // (silver sum, checkpoint count) over the a-b path in version t.
    auto path = [&](int t, int a, int b, int lca) {
        const pll x = prefix(roots[t], 0, last, in[a]);
        const pll y = prefix(roots[t], 0, last, in[b]);
        const pll z = prefix(roots[t], 0, last, in[lca]);
        return pll{x.first + y.first - 2 * z.first,
                   x.second + y.second - 2 * z.second};
    };

    while (q--) {
        int a, b;
        ll G, S;
        std::cin >> a >> b >> G >> S;
        if (a == b) {
            std::cout << G << '\n';
            continue;
        }
        const int lca = get_lca(a, b);  // invariant across the binary search
        // The final version holds every checkpoint, so its count is the total on the path.
        const ll total = path(2 * m, a, b, lca).second;

        int lo = 0, hi = m;
        ll paid_in_silver = 0;
        while (lo <= hi) {
            const int mid = (lo + hi) / 2;
            const pll res = path(2 * mid, a, b, lca);
            if (res.first <= S) {
                paid_in_silver = res.second;
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        // Gold still needed = total - paid_in_silver; -1 when G is insufficient.
        std::cout << std::max(-1LL, G - total + paid_in_silver) << '\n';
    }
    return 0;
}

Compilation message (stderr)

currencies.cpp: In function 'int main()':
currencies.cpp:117:15: warning: unused variable 'B' [-Wunused-variable]
  117 |     const int B = sqrt(2 * n);
      |               ^
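| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...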