Submission #1118396

# | Time | Username | Problem | Language | Result | Execution time | Memory
1118396 | | rama_pang | Tree (IOI24_tree) | C++17 | 100 / 100 | 148 ms | 22948 KiB
#if not LOCAL
#define NDEBUG 1
#endif

#include <bits/stdc++.h>
using namespace std;

// Loop shorthand: `i` starts at `a` and the loop runs while `i < b`,
// incrementing `i` by one after each iteration.
#define rep(i, a, b) for(auto i = a; i < (b); ++i)
// Countdown shorthand: `x` starts at `a`; the condition `x--` tests the old
// value and then decrements, so the loop body sees a-1, a-2, ..., 0.
#define down(x, a) for (auto x = a; x--;)
// Expands to the two arguments `begin(x), end(x)`.
#define all(x) begin(x), end(x)
// Size of `x` converted to int.
#define sz(x) int(size(x))
// Shorthand for a const declaration with a deduced type.
#define let auto const

// Short aliases for common integer and container types.
using ll = long long;
using lint = ll;
using pii = pair<int, int>;
using vi = vector<int>;

// Table written by init() and read by query(): query(1, R) returns
// answer[R], with R clamped to the last index of the table.
vector<lint> answer;

// Union-find over the indices 0..n-1 in which every set also carries one
// integer tally. Tallies start at zero and are added together when two sets
// are united; they are independent of the number of elements in the set.
struct DisjointSet {
  DisjointSet(int n) : dsu(n, -1), cnt(n) {}

  // dsu[v] < 0 marks v as the representative of its set; otherwise dsu[v] is
  // a link leading toward the representative.
  std::vector<int> dsu;
  // Per-set tally; only the entry at a set's representative is read.
  std::vector<int> cnt;

  // Unites the sets containing x and y. The representative of x's set is
  // linked under the representative of y's set, which inherits both tallies.
  // Returns false (and changes no tally) when x and y already share a set.
  bool merge(int x, int y) {
    const int rootX = find(x);
    const int rootY = find(y);
    if (rootX == rootY) {
      return false;
    }
    dsu[rootX] = rootY;
    cnt[rootY] += cnt[rootX];
    return true;
  }

  // Adds dif to the tally of x's set and returns the updated tally.
  int updateCount(int x, int dif) {
    int &tally = cnt[find(x)];
    tally += dif;
    return tally;
  }

  // Returns the tally of the set containing x.
  int getCount(int x) {
    const int root = find(x);
    return cnt[root];
  }

  // Returns the representative of x's set and re-links every element visited
  // on the way directly to that representative.
  int find(int x) {
    // First pass: walk up to the representative.
    int root = x;
    while (dsu[root] >= 0) {
      root = dsu[root];
    }
    // Second pass: point each element on the walked path at the representative.
    while (x != root) {
      const int next = dsu[x];
      dsu[x] = root;
      x = next;
    }
    return root;
  }
};

// Builds the global `answer` table that query() reads.
//
// P: parent array of the tree. NOTE(review): presumably P[0] == -1 and
//    P[i] < i for i >= 1, as in the task statement; the bottom-up leaf count
//    below relies on every child having a larger index than its parent.
// W: one weight per vertex.
//
// Both vectors are taken by value and modified locally. An extra vertex 0
// with weight 0 is prepended and every original vertex v becomes v + 1.
//
// After this function returns, answer has (number of leaves) + 1 entries and
// answer[k] is the value query(1, k) returns (k clamped to the last index).
void init(vector<int> P, vector<int> W) {
  // Prepend the extra vertex and shift all parent links by one. The prepended
  // -1 at P[0] is left untouched; the original root's link becomes P[1].
  const int N = static_cast<int>(P.size()) + 1;
  P.insert(P.begin(), -1);
  W.insert(W.begin(), 0);
  for (int i = 1; i < N; ++i) {
    ++P[i];
  }

  // Child lists of the shifted tree.
  vector<vector<int>> adj(N);
  for (int i = 1; i < N; ++i) {
    adj[P[i]].push_back(i);
  }

  // cntLeaf[i]: number of leaves in the subtree of i, accumulated from high
  // indices to low. baseAns: sum of the weights of all leaves.
  vector<int> cntLeaf(N, 0);
  long long baseAns = 0;
  for (int i = N - 1; i >= 0; --i) {
    if (adj[i].empty()) {
      cntLeaf[i] = 1;
      baseAns += W[i];
    } else {
      for (int child : adj[i]) {
        cntLeaf[i] += cntLeaf[child];
      }
    }
  }

  // Vertices ordered by (weight, index) descending.
  vector<int> ord(N);
  iota(ord.begin(), ord.end(), 0);
  sort(ord.begin(), ord.end(), [&](int i, int j) {
    return pair(W[i], i) > pair(W[j], j);
  });
  // Vertex 0 has weight 0 and the smallest index, so it sorts last as long as
  // no weight is negative.
  assert(ord.back() == 0);

  DisjointSet dsu(N);
  vector<int> done(N);

  // assign() rather than resize(): the table must start from all zeros even
  // if init() is called again, because everything below accumulates with +=.
  const int tableSize = cntLeaf[0] + 1;
  answer.assign(tableSize, 0);

  // Difference-array update: +d at index l and -d at index r, each applied
  // only if the index lies inside the table. After the prefix-sum pass below
  // this is equivalent to adding d to every index in [l, r).
  auto UpdateAns = [&](long long l, long long r, long long d) {
    if (l < tableSize) answer[l] += d;
    if (r < tableSize) answer[r] -= d;
  };

  for (int u : ord) {
    // Only internal vertices other than vertex 0 are activated.
    if (adj[u].empty() || u == 0) {
      continue;
    }
    done[u] = 1;

    // Collect the tallies of the already-activated neighbouring components
    // (children first, then the parent) as they were before joining u, and
    // unite each of those components with u.
    vector<int> ls;
    for (int v : adj[u]) {
      if (done[v]) {
        ls.push_back(dsu.getCount(v));
        dsu.merge(v, u);
      }
    }
    if (done[P[u]]) {
      ls.push_back(dsu.getCount(P[u]));
      dsu.merge(u, P[u]);
    }
    sort(ls.begin(), ls.end());

    // Sweep the sorted tallies. For each stretch [ptr, ls[i] + 1) the table
    // is lowered by W[u] times the number of tallies that come later in the
    // sorted order; `total` accumulates (stretch length) * (that number).
    const int parts = static_cast<int>(ls.size());
    long long ptr = 1;
    long long total = 0;
    for (int i = 0; i < parts; ++i) {
      if (ptr <= ls[i]) {
        UpdateAns(ptr, ls[i] + 1, -1ll * W[u] * (parts - i - 1));
        total += 1ll * (ls[i] + 1 - ptr) * (parts - i - 1);
        ptr = ls[i] + 1;
      }
    }

    // u itself contributes (number of children - 1), both to the length of
    // the final +W[u] range and to the tally of its merged component.
    const int extra = static_cast<int>(adj[u].size()) - 1;
    total += extra;
    UpdateAns(ptr, ptr + total, W[u]);
    dsu.updateCount(u, extra);
  }

  // First pass: prefix sums turn the recorded differences into per-index values.
  for (int i = 1; i < tableSize; ++i) {
    answer[i] += answer[i - 1];
  }
  // Second pass: suffix sums, seeded with baseAns at the last index, so that
  // answer[i] = baseAns + (sum of the per-index values at positions >= i).
  for (int i = tableSize - 1; i >= 0; --i) {
    if (i + 1 < tableSize) {
      answer[i] += answer[i + 1];
    } else {
      answer[i] += baseAns;
    }
  }
}

// Answers one (L, R) query from the table built by init().
//
// For L == 1 the result is answer[R], with R clamped to the last table index.
// For L > 1 the result is derived from the two L == 1 lookups at
// floor(R / L) and ceil(R / L):
//   L * f(ceil) + ((L - R mod L) mod L) * (f(floor) - f(ceil)),
// i.e. L times the linear interpolation of the table at R / L.
long long query(int L, int R) {
  if (!(L > 1)) {
    assert(L == 1 && R >= 1);
    const int lastIndex = static_cast<int>(answer.size()) - 1;
    return answer[std::min(R, lastIndex)];
  }

  const int ceilRatio = (R + L - 1) / L;
  const int floorRatio = R / L;
  const long long atCeil = query(1, ceilRatio);
  const long long atFloor = query(1, floorRatio);

  // Distance from R up to the next multiple of L; zero when L divides R.
  const int gapToNextMultiple = (L - R % L) % L;
  return L * atCeil + gapToNextMultiple * (atFloor - atCeil);
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...