Submission #1216006

# | Time | Username | Problem | Language | Result | Execution time | Memory
1216006 | — | abczz | JOI tour (JOI24_joitour) | C++20 | 6 / 100 | 2659 ms | 677408 KiB
#include "joitour.h"
#include <iostream>
#include <vector>
#include <array>
#define ll long long

using namespace std;

// Global state. Arrays are sized for up to 200000 vertices and 18 centroid
// levels; an index k below is a centroid-decomposition level.

// Tree adjacency lists.
vector <ll> adj[200000];
// R[k][u] = {first, last} post-order index of u's level-k subtree; u itself
// is stored at index R[k][u][1] of ST[k].
array <ll, 2> R[18][200000];
// to[k][y] = pair totals {w[0], w[1]} (see dfs_calc) of the branch of a
// level-k centroid whose own level-(k+1) centroid is y.
array <ll, 2> to[18][200000];
// rt[k] = centroids chosen at level k.
vector <ll> rt[18];
// C[u] = level at which u was made a centroid; -1 until then.
ll C[200000];
// T[u] = current type of vertex u.
ll T[200000];
// Scratch subtree sizes shared by the DFS helpers.
ll sz[200000];
// dfn[k] = next free post-order index at level k.
ll dfn[18];
// Deepest centroid level reached.
ll grp;
// tot[k] = number of heavy-light chains created at level k.
ll tot[18];
// Running answer returned by num_tours().
ll f;
// G[k][u] = centroid of the level-k component containing u.
ll G[18][200000];
// X[k][u] = id of u's heavy-light chain at level k; Y[k][u] = u's position in it.
ll X[18][200000];
ll Y[18][200000];
// cnt[k][c][0..2] = number of type-0/1/2 vertices in the component of centroid c;
// cnt[k][c][3..4] = sums over c's branches of to[k][.][0..1].
ll cnt[18][200000][5];
// Z[c] = sum over c's branches of (#type 0) * (#type 2) inside that branch.
ll Z[200000];
// P[k][j] = parent of the top vertex of chain j at level k (-1 for a root chain).
ll P[18][200000];

// Component-wise sum of two {count, count} pairs; used to merge segment-tree
// nodes and to accumulate dfs_calc results.
array<ll, 2> operator+(array<ll, 2> a, array<ll, 2> b) {
  for (size_t i = 0; i < a.size(); ++i) a[i] += b[i];
  return a;
}
struct SegTree{
  vector <array<ll, 2> > st;
  vector <ll> A;
  void init() {
    st.resize(4*(ll)A.size());
  }
  void build(ll id, ll l, ll r) {
    if (l == r) {
      st[id] = {A[l] == 0, A[l] == 2};
      return;
    }
    ll mid = (l+r)/2;
    build(id*2, l, mid);
    build(id*2+1, mid+1, r);
    st[id] = st[id*2] + st[id*2+1];
  }
  void update(ll id, ll l, ll r, ll q, ll w) {
    if (l == r) {
      st[id] = {w == 0, w == 2};
      return;
    }
    ll mid = (l+r)/2;
    if (q <= mid) update(id*2, l, mid, q, w);
    else update(id*2+1, mid+1, r, q, w);
    st[id] = st[id*2] + st[id*2+1];
  }
  array<ll, 2> query(ll id, ll l, ll r, ll ql, ll qr) {
    if (qr < l || r < ql) return {0, 0};
    else if (ql <= l && r <= qr) return st[id];
    ll mid = (l+r)/2;
    return query(id*2, l, mid, ql, qr) + query(id*2+1, mid+1, r, ql, qr);
  }
}ST[18], tmp;
vector <SegTree> ST_hld[18];
// Walks the not-yet-removed part of the tree (vertices with C == -1) that
// contains u, rooted at u with parent p, and stores subtree sizes in sz[].
// w is the size of that whole component. Returns a vertex whose largest
// remaining piece has at most w/2 vertices (a centroid): u itself if it
// qualifies, otherwise the largest id returned by a child; -1 if none.
ll dfs_sz(ll u, ll p, ll w) {
  ll found = -1;
  ll heaviest = -1;
  ll below = 0;
  for (ll v : adj[u]) {
    if (v == p || C[v] != -1) continue;
    found = max(found, dfs_sz(v, u, w));
    heaviest = max(heaviest, sz[v]);
    below += sz[v];
  }
  // The piece "above" u holds everything outside u's subtree.
  heaviest = max(heaviest, w - 1 - below);
  sz[u] = below + 1;
  return heaviest <= w / 2 ? u : found;
}
// Numbers the not-yet-removed component containing u in post-order for
// level k, rooted at u when p == -1. For each vertex it does the following:
// - appends T[u] to ST[k].A;
// - records the subtree's index range in R[k][u] (u itself at R[k][u][1]);
// - stores the subtree size in sz[u];
// - sets G[k][u] to the root of this traversal.
void dfs_order(ll u, ll p, ll k) {
  G[k][u] = (p == -1) ? u : G[k][p];
  R[k][u][0] = dfn[k];
  ll total = 1;
  for (ll v : adj[u]) {
    if (v != p && C[v] == -1) {
      dfs_order(v, u, k);
      total += sz[v];
    }
  }
  ST[k].A.push_back(T[u]);
  sz[u] = total;
  R[k][u][1] = dfn[k]++;
}
// Fills sz[] with the subtree sizes of the level-z component rooted at u
// (parent p). The component is reached through vertices with C[v] > z.
// Returns sz[u].
static ll hld_subtree_size(ll u, ll p, ll z) {
  sz[u] = 1;
  for (auto v : adj[u]) {
    if (C[v] <= z || v == p) continue;
    sz[u] += hld_subtree_size(v, u, z);
  }
  return sz[u];
}

// Heavy-light decomposition of a level-z centroid component, rooted at its
// centroid (init calls this with p == -1 for each root in rt[z]).
//   u - current vertex, p - its parent (-1 at the root),
//   k - id of the chain u belongs to, z - centroid level.
// Results:
// - X[z][u] = chain id and Y[z][u] = position inside the chain (0 = chain top);
// - the chain's segment tree stores T[u]-1 at that position;
// - P[z][chain] = parent of the chain's top vertex (-1 for the root chain).
void dfs_hld(ll u, ll p, ll k, ll z) {
  // sz[] still holds sizes from the centroid decomposition, which do not
  // describe this rooted component. Recompute them once per root so that the
  // heavy child really is the largest subtree (keeps chain hops logarithmic).
  if (p == -1) hld_subtree_size(u, -1, z);
  ll mx = -1, id = -1;
  X[z][u] = k, Y[z][u] = ST_hld[z][k].A.size();
  ST_hld[z][k].A.push_back(T[u]-1);
  // Stay inside the level-z component. The filter must compare C[v] with the
  // level z (as dfs_calc does), not with the chain id k.
  for (auto v : adj[u]) {
    if (C[v] <= z || v == p) continue;
    if (mx < sz[v]) mx = sz[v], id = v;
  }
  // Every light child starts a new chain hanging off u.
  for (auto v : adj[u]) {
    if (C[v] <= z || v == p || v == id) continue;
    ST_hld[z].push_back(tmp);
    P[z][tot[z]] = u;
    dfs_hld(v, u, tot[z]++, z);
  }
  // The heavy child continues u's chain.
  if (id != -1) dfs_hld(id, u, k, z);
}
// Centroid decomposition of the not-yet-removed component containing u,
// which has w vertices. Steps:
// 1. pick its centroid and number the component for level k (dfs_order);
// 2. mark the centroid removed at level k and record it in rt[k];
// 3. recurse into every remaining neighbouring component at level k+1.
// grp ends up as the deepest level used.
void dfs_centroid(ll u, ll k, ll w) {
  if (k > grp) grp = k;
  const ll centre = dfs_sz(u, -1, w);
  dfs_order(centre, -1, k);
  C[centre] = k;
  rt[k].push_back(centre);
  for (ll v : adj[centre]) {
    // sz[v] was just set by dfs_order rooted at centre: v's whole piece.
    if (C[v] == -1) dfs_centroid(v, k + 1, sz[v]);
  }
}
array<ll, 2> dfs_calc(ll u, ll p, ll k) {
  array<ll, 2> ret = {0LL, 0LL};
  for (auto v : adj[u]) {
    if (v == p || C[v] <= k) continue;
    ret = ret + dfs_calc(v, u, k);
  }
  if (T[u] == 1) {
    auto w = ST[k].query(1, 0, ST[k].A.size()-1, R[k][u][0], R[k][u][1]);
    ret = ret + (array<ll, 2>){w[0], w[1]};
  }
  return ret;
}
// Computes the per-centroid counters and the initial answer f once, after the
// decomposition is built. Levels run from the deepest (grp) up to 0, so every
// branch (a level-(i+1) component) is complete before the level-i centroid
// above it reads its counters. For a level-i centroid x:
//   cnt[i][x][0..2]            - number of type-0/1/2 vertices in x's component;
//   cnt[i][x][3], cnt[i][x][4] - sums over branches of dfs_calc's w[0], w[1];
//   to[i][y]                   - the {w[0], w[1]} of the branch whose level-(i+1)
//                                centroid is y;
//   Z[x]                       - sum over branches of (#type 0) * (#type 2) in
//                                the branch.
// NOTE(review): cnt, Z and f are only ever added to here, so this must run
// exactly once (in this file it is called only from init).
void recalc() {
  for (int i=grp; i>=0; --i) {
    for (auto x : rt[i]) {
      ++cnt[i][x][T[x]];
      // No centroids exist deeper than grp, so x has no branches to merge.
      if (i == grp) continue;
      // Neighbours with C[v] > i lie in x's component. Each leads into a
      // distinct branch whose level-(i+1) centroid is G[i+1][v].
      for (auto v : adj[x]) {
        if (C[v] > i) {
          for (int j=0; j<3; ++j) cnt[i][x][j] += cnt[i+1][G[i+1][v]][j];
        }
      }
      for (auto v : adj[x]) {
        if (C[v] > i) {
          Z[x] += cnt[i+1][G[i+1][v]][0] * cnt[i+1][G[i+1][v]][2];
          // w[0]/w[1]: summed over type-1 vertices of this branch, the number
          // of type-0/type-2 vertices in their level-i subtrees.
          auto w = dfs_calc(v, x, i);
          to[i][G[i+1][v]] = {w[0], w[1]};
          // Each (type-1, type-0 below it) pair combines with a type-2 vertex
          // outside this branch (another branch or x itself), and each
          // (type-1, type-2 below it) pair with a type-0 vertex outside it.
          f += w[0] * (cnt[i][x][2]-cnt[i+1][G[i+1][v]][2]);
          f += w[1] * (cnt[i][x][0]-cnt[i+1][G[i+1][v]][0]);
          cnt[i][x][3] += w[0], cnt[i][x][4] += w[1];
        }
      }
      // x itself of type 1 in the middle: one type-0 and one type-2 vertex
      // taken from different branches (all pairs minus same-branch pairs).
      if (T[x] == 1) f += cnt[i][x][0] * cnt[i][x][2] - Z[x];
    }
  }
}
ll jump(ll z, ll u) {
  ll res = 0;
  while (u != -1) {
    res += ST_hld[z][X[z][u]].query(1, 0, ST_hld[z][X[z][u]].A.size()-1, 0, Y[z][u])[0];
    u = P[z][X[z][u]];
  }
  return res;
}
void init(int N, std::vector<int> F, std::vector<int> U, std::vector<int> V, int Q) {
  for (int i=0; i<N; ++i) C[i] = -1, T[i] = F[i];
  for (int i=0; i<N-1; ++i) {
    adj[U[i]].push_back(V[i]);
    adj[V[i]].push_back(U[i]);
  }
  dfs_centroid(0, 0, N);
  for (int i=grp; i>=0; --i) {
    for (auto x : rt[i]) {
      P[i][tot[i]] = -1;
      ST_hld[i].push_back(tmp);
      dfs_hld(x, -1, tot[i]++, i);
    }
  }
  for (int i=0; i<=grp; ++i) {
    ST[i].init();
    ST[i].build(1, 0, ST[i].A.size()-1);
    for (int j=0; j<tot[i]; ++j) {
      ST_hld[i][j].init();
      ST_hld[i][j].build(1, 0, ST_hld[i][j].A.size()-1);
    }
  }
  recalc();
}

// Adds (sign = +1) or removes (sign = -1) the contributions that vertex x,
// with its current type T[x], makes at every centroid level whose component
// contains x. Affected state: the answer f, the branch pair totals
// cnt[i][c][3..4] / to[i][b], and Z[c]. Levels run from 0 down to C[x];
// at level C[x] x is itself the centroid and the walk stops.
static void apply_contribution(int x, ll sign) {
  for (int i=0; i<=grp; ++i) {
    const ll c = G[i][x];  // centroid of x's component at level i
    if (c == x) {
      // x is the centroid here and pairs with whole branches.
      if (T[x] == 1) f += sign * (cnt[i][x][0] * cnt[i][x][2] - Z[x]);
      else if (T[x] == 0) f += sign * cnt[i][x][4];
      else f += sign * cnt[i][x][3];
      break;
    }
    // Here i < C[x] <= grp, so x lies in the branch of c whose level-(i+1)
    // centroid is b, and only that branch's term of Z[c] depends on x.
    // This must be skipped when c == x: x is removed before level i+1 is
    // numbered, so G[i+1][x] is never assigned (stale 0). Updating Z there
    // would corrupt Z[x] right before it is read above.
    const ll b = G[i+1][x];
    Z[c] += sign * (cnt[i+1][b][0] * cnt[i+1][b][2]);
    if (T[x] == 1) {
      // x in the middle: w = {#type 0, #type 2} in x's level-i subtree.
      auto w = ST[i].query(1, 0, ST[i].A.size()-1, R[i][x][0], R[i][x][1]);
      f += sign * ((cnt[i][c][0] - cnt[i+1][b][0]) * w[1] + (cnt[i][c][2] - cnt[i+1][b][2]) * w[0]);
      cnt[i][c][3] += sign * w[0], to[i][b][0] += sign * w[0];
      cnt[i][c][4] += sign * w[1], to[i][b][1] += sign * w[1];
    }
    else if (T[x] == 0) {
      // w = number of type-1 vertices on the path x .. c, c included.
      ll w = jump(i, x);
      f += sign * ((cnt[i][c][2] - cnt[i+1][b][2]) * w + (cnt[i][c][4] - to[i][b][1]));
      // Pairs (type-1 ancestor inside the branch, x): c itself is excluded.
      ll pairs = w - (T[c] == 1);
      cnt[i][c][3] += sign * pairs, to[i][b][0] += sign * pairs;
    }
    else {
      ll w = jump(i, x);
      f += sign * ((cnt[i][c][0] - cnt[i+1][b][0]) * w + (cnt[i][c][3] - to[i][b][0]));
      ll pairs = w - (T[c] == 1);
      cnt[i][c][4] += sign * pairs, to[i][b][1] += sign * pairs;
    }
  }
}

// Changes the type of vertex x to y and keeps f up to date.
// Steps:
// 1. withdraw every contribution made with the old type;
// 2. update the per-level structures;
// 3. re-add the contributions with the new type.
void change(int x, int y) {
  if (T[x] == y) return;
  apply_contribution(x, -1);
  // Move x from type T[x] to y in every level whose component contains it.
  for (int i=C[x]; i>=0; --i) {
    ST[i].update(1, 0, ST[i].A.size()-1, R[i][x][1], y);
    ST_hld[i][X[i][x]].update(1, 0, ST_hld[i][X[i][x]].A.size()-1, Y[i][x], y-1);
    --cnt[i][G[i][x]][T[x]], ++cnt[i][G[i][x]][y];
  }
  T[x] = y;
  apply_contribution(x, +1);
}

// Grader query: returns the running total f maintained by init()/change()
// (presumably the number of valid tours for the current types).
long long num_tours() {
  const long long answer = f;
  return answer;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...