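/*
 * Submission #1178928
 *
 * | #       | Submitted at    | User     | Problem                  | Language | Result    | Execution time | Memory    |
 * |---------|-----------------|----------|--------------------------|----------|-----------|----------------|-----------|
 * | 1178928 | (not captured)  | rxlfd314 | JOI tour (JOI24_joitour) | C++20    | 100 / 100 | 2373 ms        | 86484 KiB |
 */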
#include "joitour.h"
#include <bits/stdc++.h>
using namespace std;
// Short type aliases (the array aliases are not used by the code in this file).
using ll = long long;
using ari2 = array<int, 2>;
using ari3 = array<int, 3>;
using arl2 = array<ll, 2>;
using arl3 = array<ll, 3>;

// vt: shorthand for vector.
#define vt vector
// all(x): expands to the iterator pair begin(x), end(x).
#define all(x) begin(x), end(x)
// size(x): the container's size as a signed int.
// NOTE(review): as a function-like macro this also captures any later member
// call written `x.size()` (expanded with an empty argument), which then fails
// to compile; code after this line must write size(x) instead.
#define size(x) (int((x).size()))

// REP(a, b, c, d): loop variable a from b to c inclusive in steps of d,
// counting up while d > 0 and down otherwise. FOR / ROF are the +1 / -1 cases.
#define REP(a, b, c, d) for (int a = (b); (d) > 0 ? a <= (c) : a >= (c); a += (d))
#define FOR(a, b, c) REP(a, b, c, 1)
#define ROF(a, b, c) REP(a, b, c, -1)

// Segment-tree node aggregate used by ST.
//   sum     - sum over the covered positions of value * multiplicity
//   all_sum - sum over the covered positions of value
//   cnt     - sum of the multiplicities over the covered positions
//   lz      - pending value-add not yet pushed to the children
struct Node {
  long long sum, all_sum;
  int cnt, lz;
  // Combines two child aggregates; the combined node has no pending lazy add.
  // const-qualified so it also works on const operands.
  Node operator+(const Node& other) const {
    return {sum + other.sum, all_sum + other.all_sum, cnt + other.cnt, 0};
  }
};

constexpr int mxN = 200000;  // capacity: maximum number of vertices
int N;                       // number of vertices, set by init()
int F[mxN];                  // F[v]: current colour of vertex v (0, 1 or 2)
int cnt[3];                  // cnt[c]: number of vertices currently of colour c
vt<int> adj[mxN];            // adjacency lists of the tree

// Lazy segment tree over positions 0..N-1 (the HLD indices tin[]).
// Each position p has a value (the leaf's all_sum) and a multiplicity (the
// leaf's cnt). Every node keeps, over its range:
//   sum     = sum of value * multiplicity
//   all_sum = sum of value
//   cnt     = sum of multiplicity
// Operations:
//   radd(ql, qr, v): add v to the value of every position in [ql, qr]
//   update(p, v):    add v to the multiplicity of position p
//   query(ql, qr):   combined Node over [ql, qr]
// The node array has a fixed 800000 = 4 * 200000 entries.
struct ST {
  Node st[800000];
  #define lc i << 1
  #define rc lc | 1
  // Pushes node i's pending value-add (node i covers [tl, tr]) to its children.
  void push(int i, int tl, int tr) {
    if (!st[i].lz)
      return;
    const int tm = (tl + tr) >> 1;
    st[lc].sum += 1ll * st[lc].cnt * st[i].lz;
    st[lc].all_sum += 1ll * (tm - tl + 1) * st[i].lz;
    st[lc].lz += st[i].lz;
    st[rc].sum += 1ll * st[rc].cnt * st[i].lz;
    st[rc].all_sum += 1ll * (tr - tm) * st[i].lz;
    st[rc].lz += st[i].lz;
    st[i].lz = 0;
  }
  // Adds v to the value of every position in [ql, qr].
  void radd(int ql, int qr, int v, int i = 1, int tl = 0, int tr = N-1) {
    if (tl > qr || tr < ql)
      return;
    if (ql <= tl && tr <= qr) {
      // Multiply in 64 bits, as push() does: cnt * v and v * length are
      // products of two ints and would overflow int for a large v.
      st[i].sum += 1ll * st[i].cnt * v;
      st[i].all_sum += 1ll * v * (tr - tl + 1);
      st[i].lz += v;
    } else {
      push(i, tl, tr);
      const int tm = (tl + tr) >> 1;
      radd(ql, qr, v, lc, tl, tm);
      radd(ql, qr, v, rc, tm+1, tr);
      st[i] = st[lc] + st[rc];
    }
  }
  // Adds v to the multiplicity of position p. Walks down pushing lazies, fixes
  // the leaf (sum = value * multiplicity), then recomputes every ancestor.
  void update(int p, int v, int i = 1, int tl = 0, int tr = N-1) {
    while (tl < tr) {
      push(i, tl, tr);
      const int tm = (tl + tr) >> 1;
      if (p <= tm)
        i = lc, tr = tm;
      else
        i = rc, tl = tm + 1;
    }
    st[i].cnt += v;
    st[i].sum = st[i].all_sum * st[i].cnt;
    for (i >>= 1; i > 0; i >>= 1)
      st[i] = st[lc] + st[rc];
  }
  // Returns the combined Node over [ql, qr]; {0, 0, 0, 0} when disjoint.
  Node query(int ql, int qr, int i = 1, int tl = 0, int tr = N-1) {
    if (tl > qr || tr < ql)
      return {0, 0, 0, 0};
    if (ql <= tl && tr <= qr)
      return st[i];
    push(i, tl, tr);
    const int tm = (tl + tr) >> 1;
    return query(ql, qr, lc, tl, tm) + query(ql, qr, rc, tm+1, tr);
  }
  #undef lc
  #undef rc
};

int szs[mxN], parent[mxN], hson[mxN];
void dfs_szs(int i, int p) {
  parent[i] = p;
  szs[i] = 1;
  hson[i] = -1;
  for (int j : adj[i])
    if (j != p) {
      dfs_szs(j, i);
      szs[i] += szs[j];
      if (hson[i] < 0 || szs[j] > szs[hson[i]])
        hson[i] = j;
    }
}

int head[mxN], tin[mxN], tout[mxN], timer;
void dfs_hld(int i) {
  tin[i] = timer++;
  if (hson[i] >= 0) {
    head[hson[i]] = head[i];
    dfs_hld(hson[i]);
  }
  for (int j : adj[i])
    if (j != parent[i] && j != hson[i]) {
      head[j] = j;
      dfs_hld(j);
    }
  tout[i] = timer - 1;
}

// ans: the value subtracted from cnt[0] * cnt[1] * cnt[2] in num_tours().
// sum[b]: light-child products accumulated by add() for vertex b.
ll ans, sum[mxN];
// Per-colour segment trees indexed by HLD position tin[].
ST above_st[3], below_st[3];
// Inserts (v = 1) or removes (v = -1) vertex i of colour c = F[i]. Callers only
// use this for c != 1. For every vertex b, at position tin[b]:
//   above_st[c] gets +v unless b is i or an ancestor of i, so its value counts
//     the colour-c vertices outside subtree(b);
//   below_st[c] gets +v exactly when b is i or an ancestor of i, so its value
//     counts the colour-c vertices inside subtree(b).
// The loop walks i's root path one heavy chain at a time.
void update(int i, int v) {
  const int c = F[i];
  above_st[c].radd(0, N-1, v);
  for (; i >= 0; i = parent[head[i]]) {
    const int h = head[i];
    above_st[c].radd(tin[h], tin[i], -v);
    below_st[c].radd(tin[h], tin[i], v);
  }
}

void add(int i, int v) {
  ll bef_ans = ans;
  const int c = F[i], oc = c ^ 2;
  ans += above_st[oc].query(0, N-1).sum * v;
  for (; i >= 0; i = parent[head[i]]) {
    int h = head[i];
    ans -= above_st[oc].query(tin[h], tin[i]).sum * v;
    ans += below_st[oc].query(tin[h], tin[i]).sum * v;
    if (h) {
      int x = below_st[oc].query(tin[h], tin[h]).all_sum;
      sum[parent[h]] += x * v;
      if (F[parent[h]] == 1)
        ans += x * v;
    }
  }
}

void init(int32_t _N, vt<int32_t> _F, vt<int32_t> U, vt<int32_t> V, int32_t Q) {
  N = _N;
  FOR(i, 0, N-1) {
    F[i] = _F[i];
    cnt[F[i]]++;
  }
  FOR(i, 0, N-2) {
    adj[U[i]].push_back(V[i]);
    adj[V[i]].push_back(U[i]);
  }
  dfs_szs(0, -1);
  dfs_hld(0);

  FOR(i, 0, N-1)
    if (F[i] == 1) {
      above_st[0].update(tin[i], 1);
      above_st[2].update(tin[i], 1);
      if (hson[i] >= 0) {
        below_st[0].update(tin[hson[i]], 1);
        below_st[2].update(tin[hson[i]], 1);
      }
    }
  FOR(i, 0, N-1)
    if (F[i] != 1) {
      update(i, 1);
      add(i, 1);
    } 
}

// Grader entry point: recolours vertex i to colour c, keeping ans, sum[], cnt[]
// and the segment trees consistent. The old colour's contribution is withdrawn
// first, then the new colour's contribution is added.
// Statement order matters: a leaf's .sum is all_sum * cnt, so it must be read
// while i's multiplicity is 1. That means before the -1 updates when leaving
// colour 1, and after the +1 updates when entering it.
void change(int32_t i, int32_t c) {
  if (F[i] == 1) {
    // Leaving colour 1: remove i's own terms from ans.
    // Light-child pieces of i (accumulated in sum[i] by add()).
    ans -= sum[i];
    // Piece outside subtree(i): (#colour 0 outside) * (#colour 2 outside).
    ans -= 1ll * above_st[0].query(tin[i], tin[i]).sum * above_st[2].query(tin[i], tin[i]).sum;
    if (hson[i] >= 0) {
      int h = hson[i];
      // Heavy-child piece: (#colour 0) * (#colour 2) inside subtree(h).
      ans -= 1ll * below_st[0].query(tin[h], tin[h]).all_sum * below_st[2].query(tin[h], tin[h]).all_sum;
    }
    // Unregister i (and its heavy child's position) as belonging to a colour-1 vertex.
    above_st[0].update(tin[i], -1);
    above_st[2].update(tin[i], -1);
    if (hson[i] >= 0) {
      int h = hson[i];
      below_st[0].update(tin[h], -1);
      below_st[2].update(tin[h], -1);
    }
  } else {
    // Leaving colour 0 or 2: withdraw i's effect on ans/sum[], then remove it
    // from the trees of its colour.
    add(i, -1);
    update(i, -1);
  }
  // Move i from its old colour count to the new one.
  cnt[F[i]]--;
  cnt[F[i]=c]++;
  if (F[i] == 1) {
    // Entering colour 1: register first so the leaf .sum values are non-zero,
    // then add the same three kinds of terms that are removed above.
    above_st[0].update(tin[i], 1);
    above_st[2].update(tin[i], 1);
    if (hson[i] >= 0) {
      int h = hson[i];
      below_st[0].update(tin[h], 1);
      below_st[2].update(tin[h], 1);
    }
    ans += sum[i];
    ans += 1ll * above_st[0].query(tin[i], tin[i]).sum * above_st[2].query(tin[i], tin[i]).sum;
    if (hson[i] >= 0) {
      int h = hson[i];
      ans += 1ll * below_st[0].query(tin[h], tin[h]).all_sum * below_st[2].query(tin[h], tin[h]).all_sum;
    }
  } else {
    // Entering colour 0 or 2: insert into the trees, then account in ans/sum[].
    update(i, 1);
    add(i, 1);
  }
}

// Grader entry point. Returns the number of (colour-0, colour-1, colour-2)
// vertex triples, cnt[0] * cnt[1] * cnt[2], minus ans.
// As maintained by add() and change(), ans counts, per colour-1 vertex, the
// colour-0/colour-2 pairs that do not lie on opposite sides of it.
ll num_tours() {
  const ll triples = static_cast<ll>(cnt[0]) * cnt[1] * cnt[2];
  return triples - ans;
}
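/*
 * Result tables (7 copies) whose contents had not loaded when this page was
 * captured:
 *
 * | # | Verdict | Execution time | Memory | Grader output |
 * Fetching results...
 * | # | Verdict | Execution time | Memory | Grader output |
 * Fetching results...
 * | # | Verdict | Execution time | Memory | Grader output |
 * Fetching results...
 * | # | Verdict | Execution time | Memory | Grader output |
 * Fetching results...
 * | # | Verdict | Execution time | Memory | Grader output |
 * Fetching results...
 * | # | Verdict | Execution time | Memory | Grader output |
 * Fetching results...
 * | # | Verdict | Execution time | Memory | Grader output |
 * Fetching results...
 */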