Submission #1356220

# | Time | Username | Problem | Language | Result | Execution time | Memory
1356220 | avighna | Two Currencies (JOI23_currencies) | C++20
100 / 100
2964 ms515440 KiB
#include <bits/stdc++.h>

using namespace std;

// Static wavelet tree over non-negative ints, built with value range [0, 2^w].
// For an inclusive index range [l, r] and a threshold x it answers
//   num_le(l, r, x): how many a[i] with l <= i <= r satisfy a[i] <= x, and
//   sum_le(l, r, x): the sum of those a[i],
// descending at most one path per tree level for each side of the split.
//
// Child nodes are owned by std::unique_ptr, so the tree is move-only and releases
// itself exactly once (raw `delete` destructors plus implicit copies meant any copy
// of the tree was a double free).
class wavelet_tree {
  struct node {
    std::int64_t tl, tr;  // inclusive value range [tl, tr] covered by this node
    // cnt[i] = how many of the first i elements of this node's sequence go to the left
    // child (value <= midpoint). An element at position i lands at position cnt[i] in the
    // left child, or at i - cnt[i] in the right child. Left empty at leaves: a leaf query
    // is always answered wholesale and never routed further.
    std::vector<int> cnt;
    std::vector<std::int64_t> pref;  // pref[i] = sum of the first i elements of this node's sequence
    std::unique_ptr<node> l, r;      // null when no element falls into that half
    node(std::int64_t tl, std::int64_t tr) : tl(tl), tr(tr) {}
  };

  std::unique_ptr<node> t;

  // Worked example of the routing table (value range [0,5] split into [0,2] | [3,5]):
  //   a           = [3, 1, 5, 0, 2]
  //   goes left   =  0  1  0  1  1
  //   cnt         = [0, 0, 1, 1, 2, 3]
  //   left child  = [1, 0, 2]  (original positions 1, 3, 4)
  //   right child = [3, 5]     (original positions 0, 2)
  // Index range [1,3] maps to [cnt[1], cnt[4]-1] = [0,1] on the left
  // and to [1-cnt[1], 3-cnt[4]] = [1,1] on the right.

  // Fills node t from the sequence [itl, itr). Stably partitions that subrange of the
  // caller's vector (left-goers first) while recursing into the children.
  static void make(node *t, std::vector<int>::iterator itl, std::vector<int>::iterator itr) {
    const std::int64_t tm = t->tl + (t->tr - t->tl) / 2;  // equals std::midpoint(tl, tr) for tl <= tr
    const auto goes_left = [tm](int x) { return x <= tm; };
    const auto len = static_cast<std::size_t>(itr - itl);

    t->pref.reserve(len + 1);
    t->pref.push_back(0);
    for (auto it = itl; it != itr; ++it) {
      t->pref.push_back(t->pref.back() + *it);
    }

    if (t->tl == t->tr) {
      return;  // leaf: queries here never consult cnt
    }

    t->cnt.reserve(len + 1);
    t->cnt.push_back(0);
    for (auto it = itl; it != itr; ++it) {
      t->cnt.push_back(t->cnt.back() + (goes_left(*it) ? 1 : 0));
    }

    const auto itm = std::stable_partition(itl, itr, goes_left);
    if (itl != itm) {
      t->l = std::make_unique<node>(t->tl, tm);
      make(t->l.get(), itl, itm);
    }
    if (itm != itr) {
      t->r = std::make_unique<node>(tm + 1, t->tr);
      make(t->r.get(), itm, itr);
    }
  }

  // Counts elements <= x among positions [l, r] of t's sequence.
  static int num_le_rec(const node *t, int l, int r, int x) {
    if (t == nullptr || l > r || x < t->tl) {
      return 0;  // nothing here, or every value here exceeds x
    }
    if (t->tr <= x) {
      return r - l + 1;  // every value here is <= x
    }
    const int left_before = t->cnt[l];       // left-goers among positions [0, l)
    const int left_through = t->cnt[r + 1];  // left-goers among positions [0, r]
    return num_le_rec(t->l.get(), left_before, left_through - 1, x) +
           num_le_rec(t->r.get(), l - left_before, r - left_through, x);
  }

  // Sums elements <= x among positions [l, r] of t's sequence.
  static std::int64_t sum_le_rec(const node *t, int l, int r, int x) {
    if (t == nullptr || l > r || x < t->tl) {
      return 0;
    }
    if (t->tr <= x) {
      return t->pref[r + 1] - t->pref[l];
    }
    const int left_before = t->cnt[l];
    const int left_through = t->cnt[r + 1];
    return sum_le_rec(t->l.get(), left_before, left_through - 1, x) +
           sum_le_rec(t->r.get(), l - left_before, r - left_through, x);
  }

public:
  // Builds the tree over arr with root value range [0, 2^w].
  // NOTE: assumes 0 <= arr[i] <= 2^w and w <= 62; values outside that range are not handled.
  wavelet_tree(std::vector<int> arr, int w) : t(std::make_unique<node>(0, std::int64_t{1} << w)) {
    make(t.get(), arr.begin(), arr.end());
  }

  // Number of a[i] <= x for l <= i <= r; 0 when l > r. Requires 0 <= l, r < size when l <= r.
  int num_le(int l, int r, int x) const { return num_le_rec(t.get(), l, r, x); }
  // Sum of a[i] <= x for l <= i <= r; 0 when l > r. Requires 0 <= l, r < size when l <= r.
  std::int64_t sum_le(int l, int r, int x) const { return sum_le_rec(t.get(), l, r, x); }
};

// Entry point for "Two Currencies" (JOI23_currencies).
//
// Input: n m q; then n - 1 roads "u v" (road i is the i-th of these lines, 1-based);
// then m checkpoints "p c" (road p carries a checkpoint costing c silver);
// then q queries "s t x y". For each query prints how many of the x gold coins remain
// after going from s to t, or -1 if x is not enough.
//
// Each checkpoint on the path is paid either with its cost in silver or with one gold
// coin. Paying the cheapest checkpoints with silver maximises how many avoid gold, so we
// binary-search the largest v such that all path checkpoints costing <= v fit in y, then
// pay as many checkpoints costing exactly v + 1 as the leftover silver allows. Every
// checkpoint not cleared this way costs one gold coin.
int main() {
  cin.tie(nullptr)->sync_with_stdio(false);

  int n, m, q;
  cin >> n >> m >> q;
  // adj[u] = list of (neighbour, 0-based road index).
  vector<vector<pair<int, int>>> adj(n + 1);
  for (int i = 0, u, v; i < n - 1; ++i) {
    cin >> u >> v;
    adj[u].push_back({v, i});
    adj[v].push_back({u, i});
  }

  // --- Heavy-light decomposition, rooted at node 1 ---

  // Pass 1: subtree sizes. Afterwards adj[u][0] is u's heaviest child (when u has a
  // child), so pass 2 visits it first.
  vector<int> subtree_sz(n + 1);
  auto dfs_sz = [&](auto &&self, int u, int p) -> void {
    // Keep the parent out of slot 0 so the comparison below is against a child.
    if (adj[u].size() > 1 && adj[u][0].first == p) {
      swap(adj[u][0], adj[u][1]);
    }
    subtree_sz[u] = 1;
    for (auto &ele : adj[u]) {
      const int child = ele.first;
      if (child == p) {
        continue;
      }
      self(self, child, u);
      subtree_sz[u] += subtree_sz[child];
      if (subtree_sz[child] > subtree_sz[adj[u][0].first]) {
        // Only slot 0 (already visited) and the current slot are exchanged,
        // so the remaining iteration is unaffected.
        swap(adj[u][0], ele);
      }
    }
  };
  dfs_sz(dfs_sz, 1, 0);

  // Pass 2: DFS entry times, heavy child first, so every heavy chain is a contiguous tin
  // range. up[u] = top of u's chain, par[u] = parent, node_with_tin inverts tin, and
  // node_of[road] = the lower endpoint of that road (the road is that node's parent edge).
  int timer = 0;
  vector<int> tin(n + 1), node_with_tin(n), up(n + 1), par(n + 1), node_of(n - 1);
  auto dfs_hld = [&](auto &&self, int u, int p) -> void {
    node_with_tin[timer] = u;
    tin[u] = timer++;
    for (const auto &[child, road] : adj[u]) {
      if (child == p) {
        continue;
      }
      node_of[road] = child;
      par[child] = u;
      up[child] = child == adj[u][0].first ? up[u] : child;  // heavy child extends u's chain
      self(self, child, u);
    }
  };
  up[1] = 1;
  dfs_hld(dfs_hld, 1, 0);

  // Each checkpoint on road p (1-based in the input) is charged to that road's lower endpoint.
  vector<vector<int>> cs(n + 1);
  for (int i = 0, p, c; i < m; ++i) {
    cin >> p >> c;
    cs[node_of[p - 1]].push_back(c);
  }

  // Lay all costs out in tin order: the costs of the node with tin ti occupy
  // arr[start[ti] .. start[ti + 1] - 1]. Signed ints make an empty node yield r == l - 1
  // directly, instead of relying on size_t wrap-around narrowed back to int.
  vector<int> start(n + 1, 0);
  vector<int> arr;
  arr.reserve(m);
  for (int ti = 0; ti < n; ++ti) {
    const vector<int> &here = cs[node_with_tin[ti]];
    start[ti + 1] = start[ti] + static_cast<int>(here.size());
    arr.insert(arr.end(), here.begin(), here.end());
  }
  wavelet_tree ds(std::move(arr), 31);  // arr is not used again

  // Inclusive tin range [til, tir] -> inclusive index range in arr (l > r when it holds no costs).
  auto transform_lr = [&](int til, int tir) -> pair<int, int> {
    return {start[til], start[tir + 1] - 1};
  };
  auto query_sum_le = [&](int til, int tir, int x) -> int64_t {
    const auto [l, r] = transform_lr(til, tir);
    if (l > r) {
      return 0;
    }
    return ds.sum_le(l, r, x);
  };
  auto query_num_le = [&](int til, int tir, int x) -> int {
    const auto [l, r] = transform_lr(til, tir);
    if (l > r) {
      return 0;
    }
    return ds.num_le(l, r, x);
  };

  // Splits the u-v path into inclusive tin ranges, one per heavy-chain segment. Each range
  // lists nodes whose parent edge lies on the path, so together they hold exactly the
  // path's checkpoints.
  auto get_chains_hld = [&](int u, int v) {
    vector<pair<int, int>> ans;
    while (u != v) {
      // Advance the endpoint with the larger tin (not necessarily the deeper one). If its
      // chain head were an ancestor of the other endpoint without sharing its chain, that
      // endpoint would have been visited after it. So the whole segment from u up to its
      // chain head lies on the path.
      if (tin[u] < tin[v]) {
        swap(u, v);
      }
      if (up[u] == up[v]) {
        // Same chain: v is an ancestor of u; take the nodes strictly below v down to u.
        ans.push_back({tin[v] + 1, tin[u]});
        return ans;
      }
      ans.push_back({tin[up[u]], tin[u]});
      u = par[up[u]];
    }
    return ans;
  };

  while (q--) {
    int64_t s, t, x, y;
    cin >> s >> t >> x >> y;
    const auto chains = get_chains_hld(s, t);
    // Total cost of path checkpoints costing <= cap.
    auto query_chains_sum = [&](int cap) {
      int64_t sum = 0;
      for (const auto &[l, r] : chains) {
        sum += query_sum_le(l, r, cap);
      }
      return sum;
    };
    // Number of path checkpoints costing <= cap.
    auto query_chains_num = [&](int cap) {
      int tot = 0;
      for (const auto &[l, r] : chains) {
        tot += query_num_le(l, r, cap);
      }
      return tot;
    };

    // Largest v in [-1, 1e9 - 1] with query_chains_sum(v) <= y (the predicate is monotone).
    const int v = *ranges::partition_point(views::iota(int(0), int(1e9)), [&](int cap) {
      return query_chains_sum(cap) <= y;
    }) - 1;

    int cleared = query_chains_num(v);
    const int num_vp1 = query_chains_num(v + 1) - cleared;  // checkpoints costing exactly v + 1
    y -= query_chains_sum(v);
    cleared += static_cast<int>(min(int64_t(num_vp1), y / (v + 1)));
    const int gold_needed = query_chains_num(int(1e9)) - cleared;

    if (x < gold_needed) {
      cout << "-1\n";
    } else {
      cout << x - gold_needed << '\n';
    }
  }
}
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...