#include "joitour.h"
#include <iostream>
#include <vector>
#include <array>
#define ll long long
using namespace std;
// Global state. Array extents (200000 nodes, 18 decomposition levels) are fixed;
// NOTE(review): presumably chosen from the problem limits — confirm N <= 200000.
// adj[u]: neighbours of node u (filled in init()).
vector <ll> adj[200000];
// R[k][u] = {first, last} post-order index of u's subtree in the level-k DFS (dfs_order),
//   i.e. u's subtree is the index range [R[k][u][0], R[k][u][1]] of ST[k].
// to[k][c] = the pair returned by dfs_calc for the child subtree whose level-(k+1)
//   centroid is c, adjusted incrementally by change().
array <ll, 2> R[18][200000], to[18][200000];
// rt[k]: centroids chosen at decomposition level k.
vector <ll> rt[18];
// C[u]: level at which u became a centroid (-1 until chosen); T[u]: current type of u (0/1/2);
// sz[u]: scratch subtree sizes; dfn[k]: next post-order index at level k;
// grp: deepest decomposition level used; tot[k]: number of HLD chains at level k;
// f: current answer returned by num_tours().
ll C[200000], T[200000], sz[200000], dfn[18], grp, tot[18], f;
// G[k][u]: level-k centroid whose component contains u.
// X[k][u] / Y[k][u]: HLD chain id and position of u inside that chain at level k.
// cnt[k][c][0..2]: number of type-0/1/2 nodes in centroid c's level-k component;
// cnt[k][c][3] / [4]: sums of dfs_calc's first / second component over c's child subtrees.
// Z[c]: sum over c's child subtrees of (#type-0 * #type-2) in that subtree.
// P[k][h]: node that HLD chain h at level k hangs from (-1 for the chain starting at the centroid).
ll G[18][200000], X[18][200000], Y[18][200000], cnt[18][200000][5], Z[200000], P[18][200000];
array<ll, 2> operator+(array<ll, 2> a, array<ll, 2> b) {
    // Component-wise sum of two pairs of counters.
    array<ll, 2> sum{};
    for (int idx = 0; idx < 2; ++idx) {
        sum[idx] = a[idx] + b[idx];
    }
    return sum;
}
struct SegTree{
vector <array<ll, 2> > st;
vector <ll> A;
void init() {
st.resize(4*(ll)A.size());
}
void build(ll id, ll l, ll r) {
if (l == r) {
st[id] = {A[l] == 0, A[l] == 2};
return;
}
ll mid = (l+r)/2;
build(id*2, l, mid);
build(id*2+1, mid+1, r);
st[id] = st[id*2] + st[id*2+1];
}
void update(ll id, ll l, ll r, ll q, ll w) {
if (l == r) {
st[id] = {w == 0, w == 2};
return;
}
ll mid = (l+r)/2;
if (q <= mid) update(id*2, l, mid, q, w);
else update(id*2+1, mid+1, r, q, w);
st[id] = st[id*2] + st[id*2+1];
}
array<ll, 2> query(ll id, ll l, ll r, ll ql, ll qr) {
if (qr < l || r < ql) return {0, 0};
else if (ql <= l && r <= qr) return st[id];
ll mid = (l+r)/2;
return query(id*2, l, mid, ql, qr) + query(id*2+1, mid+1, r, ql, qr);
}
}ST[18], tmp;
// ST_hld[k][h]: segment tree over HLD chain h at level k. Leaves store T-1 (see dfs_hld),
// so a node's "== 0" count is the number of type-1 nodes in its range (read by jump()).
vector <SegTree> ST_hld[18];
// Walks the still-undecomposed component (nodes with C == -1) containing u, rooted at u.
// Fills sz[] with subtree sizes and returns a node whose largest remaining piece has
// at most w/2 nodes (w = component size), or -1 if none was found below u.
ll dfs_sz(ll u, ll p, ll w) {
    ll found = -1;    // best centroid candidate reported by a child subtree
    ll biggest = -1;  // largest piece left after removing u
    ll below = 0;     // nodes strictly below u
    for (auto v : adj[u]) {
        if (v == p || C[v] != -1) continue;
        found = max(found, dfs_sz(v, u, w));
        biggest = max(biggest, sz[v]);
        below += sz[v];
    }
    // The part of the component above u.
    biggest = max(biggest, w - 1 - below);
    sz[u] = below + 1;
    return (biggest <= w / 2) ? u : found;
}
// Post-order DFS of the level-k component rooted at u (the centroid when p == -1).
// Records G[k][*] (the centroid), the subtree interval R[k][*], subtree sizes sz[*],
// and appends each node's type to ST[k].A at its post-order index.
void dfs_order(ll u, ll p, ll k) {
    G[k][u] = (p == -1) ? u : G[k][p];
    R[k][u][0] = dfn[k];
    ll below = 0;
    for (auto v : adj[u]) {
        if (v == p || C[v] != -1) continue;
        dfs_order(v, u, k);
        below += sz[v];
    }
    ST[k].A.push_back(T[u]);
    sz[u] = below + 1;
    R[k][u][1] = dfn[k]++;
}
void dfs_hld(ll u, ll p, ll k, ll z) {
ll mx = -1, id = -1;
X[z][u] = k, Y[z][u] = ST_hld[z][k].A.size();
ST_hld[z][k].A.push_back(T[u]-1);
for (auto v : adj[u]) {
if (C[v] <= k || v == p) continue;
if (mx < sz[v]) mx = sz[v], id = v;
}
for (auto v : adj[u]) {
if (C[v] <= k || v == p || v == id) continue;
ST_hld[z].push_back(tmp);
P[z][tot[z]] = u;
dfs_hld(v, u, tot[z]++, z);
}
if (id != -1) dfs_hld(id, u, k, z);
}
// Centroid decomposition. Finds the centroid of the size-w component containing u,
// runs dfs_order on it at level k, marks it (C = k, rt[k]), and recurses into each
// remaining neighbouring component at level k+1.
void dfs_centroid(ll u, ll k, ll w) {
    if (k > grp) grp = k;
    const ll cent = dfs_sz(u, -1, w);
    dfs_order(cent, -1, k);
    C[cent] = k;
    rt[k].push_back(cent);
    for (auto v : adj[cent]) {
        // sz[v] was just set by dfs_order rooted at cent: the size of v's piece.
        if (C[v] == -1) dfs_centroid(v, k + 1, sz[v]);
    }
}
// Sums, over every type-1 node b in u's subtree (level-k component, rooted away from p),
// the pair {#type-0 nodes in b's subtree, #type-2 nodes in b's subtree}, read from ST[k].
array<ll, 2> dfs_calc(ll u, ll p, ll k) {
    array<ll, 2> ret = {0LL, 0LL};
    for (auto v : adj[u]) {
        if (v == p || C[v] <= k) continue;
        ret = ret + dfs_calc(v, u, k);
    }
    if (T[u] == 1) {
        // Plain array addition. The original `(array<ll, 2>){...}` was a C99
        // compound literal, a GNU extension that is not valid ISO C++.
        ret = ret + ST[k].query(1, 0, ST[k].A.size()-1, R[k][u][0], R[k][u][1]);
    }
    return ret;
}
void recalc() {
for (int i=grp; i>=0; --i) {
for (auto x : rt[i]) {
++cnt[i][x][T[x]];
if (i == grp) continue;
for (auto v : adj[x]) {
if (C[v] > i) {
for (int j=0; j<3; ++j) cnt[i][x][j] += cnt[i+1][G[i+1][v]][j];
Z[x] += cnt[i+1][G[i+1][v]][0] * cnt[i+1][G[i+1][v]][2];
auto w = dfs_calc(v, x, i);
to[i][G[i+1][v]] = {w[0], w[1]};
f += w[0] * (cnt[i][x][2]-cnt[i+1][G[i+1][v]][2]);
f += w[1] * (cnt[i][x][0]-cnt[i+1][G[i+1][v]][0]);
cnt[i][x][3] += w[0], cnt[i][x][4] += w[1];
}
}
if (T[x] == 1) f += cnt[i][x][0] * cnt[i][x][2] - Z[x];
}
}
}
// Counts type-1 nodes on the path from u up to the level-z centroid (both ends
// included), by walking HLD chains until the chain whose parent link is -1.
ll jump(ll z, ll u) {
    ll ones = 0;
    for (ll node = u; node != -1; ) {
        const ll chain = X[z][node];
        SegTree &seg = ST_hld[z][chain];
        // Leaves hold T-1, so the "== 0" count is the number of type-1 nodes.
        ones += seg.query(1, 0, seg.A.size()-1, 0, Y[z][node])[0];
        node = P[z][chain];
    }
    return ones;
}
void init(int N, std::vector<int> F, std::vector<int> U, std::vector<int> V, int Q) {
for (int i=0; i<N; ++i) C[i] = -1, T[i] = F[i];
for (int i=0; i<N-1; ++i) {
adj[U[i]].push_back(V[i]);
adj[V[i]].push_back(U[i]);
}
dfs_centroid(0, 0, N);
for (int i=grp; i>=0; --i) {
for (auto x : rt[i]) {
P[i][tot[i]] = -1;
ST_hld[i].push_back(tmp);
dfs_hld(x, -1, tot[i]++, i);
}
}
for (int i=0; i<=grp; ++i) {
ST[i].init();
ST[i].build(1, 0, ST[i].A.size()-1);
for (int j=0; j<tot[i]; ++j) {
ST_hld[i][j].init();
ST_hld[i][j].build(1, 0, ST_hld[i][j].A.size()-1);
}
}
recalc();
}
// Sets the type of node x to y and updates the answer f incrementally.
// Phase 1: for each decomposition level from 0 up to the one where x is the centroid,
//          subtract x's current contribution to f and to the cnt / to / Z bookkeeping.
// Phase 2: write the new type into ST[i], ST_hld[i] and the type counters at every
//          level 0..C[x].
// Phase 3: same walk as phase 1 with the new type, adding instead of subtracting.
// NOTE(review): cnt[i+1] / G[i+1] are only read for i < C[x], because each walk
// breaks at the level where G[i][x] == x.
void change(int x, int y) {
if (T[x] == y) return;
// Phase 1: remove x's old contribution.
for (int i=0; i<=grp; ++i) {
if (G[i][x] == x) {
// x is the level-i centroid: remove its centroid-local terms, then stop
// (deeper levels do not contain x).
if (T[x] == 1) f -= cnt[i][x][0] * cnt[i][x][2] - Z[x];
else if (T[x] == 0) f -= cnt[i][x][4];
else f -= cnt[i][x][3];
break;
}
else {
// x lies in the child subtree of centroid G[i][x] identified by G[i+1][x].
if (T[x] == 1) {
// w = {#type-0, #type-2} in x's subtree (level-i rooting). Remove their
// pairings with opposite-type nodes outside x's child subtree, and x's share
// of the [3]/[4] and to sums.
auto w = ST[i].query(1, 0, ST[i].A.size()-1, R[i][x][0], R[i][x][1]);
f -= (cnt[i][G[i][x]][0] - cnt[i+1][G[i+1][x]][0]) * w[1] + (cnt[i][G[i][x]][2] - cnt[i+1][G[i+1][x]][2]) * w[0];
cnt[i][G[i][x]][3] -= w[0], to[i][G[i+1][x]][0] -= w[0];
cnt[i][G[i][x]][4] -= w[1], to[i][G[i+1][x]][1] -= w[1];
}
else if (T[x] == 0) {
// w = type-1 nodes on the path x -> centroid, centroid included (jump).
// Remove x paired via those nodes with type-2 nodes outside x's subtree, plus
// x paired with (type-1, type-2-below) pairs in the other child subtrees.
// The [3]/to sums exclude the centroid itself, hence the correction.
auto w = jump(i, x);
f -= (cnt[i][G[i][x]][2] - cnt[i+1][G[i+1][x]][2]) * w + (cnt[i][G[i][x]][4] - to[i][G[i+1][x]][1]);
cnt[i][G[i][x]][3] -= w-(T[G[i][x]] == 1), to[i][G[i+1][x]][0] -= w-(T[G[i][x]] == 1);
}
else {
// Mirror of the type-0 case with types 0 and 2 swapped.
auto w = jump(i, x);
f -= (cnt[i][G[i][x]][0] - cnt[i+1][G[i+1][x]][0]) * w + (cnt[i][G[i][x]][3] - to[i][G[i+1][x]][0]);
cnt[i][G[i][x]][4] -= w-(T[G[i][x]] == 1), to[i][G[i+1][x]][1] -= w-(T[G[i][x]] == 1);
}
}
// Drop x's child subtree term from the centroid's Z; it is re-added in phase 3
// after the counts change.
Z[G[i][x]] -= cnt[i+1][G[i+1][x]][0] * cnt[i+1][G[i+1][x]][2];
}
// Phase 2: point updates and type counters at every level containing x.
for (int i=C[x]; i>=0; --i) {
ST[i].update(1, 0, ST[i].A.size()-1, R[i][x][1], y);
// HLD trees store type-1 as leaf value 0.
ST_hld[i][X[i][x]].update(1, 0, ST_hld[i][X[i][x]].A.size()-1, Y[i][x], y-1);
--cnt[i][G[i][x]][T[x]], ++cnt[i][G[i][x]][y];
}
T[x] = y;
// Phase 3: add x's new contribution (mirror of phase 1).
for (int i=0; i<=grp; ++i) {
if (G[i][x] == x) {
if (T[x] == 1) f += cnt[i][x][0] * cnt[i][x][2] - Z[x];
else if (T[x] == 0) f += cnt[i][x][4];
else f += cnt[i][x][3];
break;
}
else {
if (T[x] == 1) {
auto w = ST[i].query(1, 0, ST[i].A.size()-1, R[i][x][0], R[i][x][1]);
f += (cnt[i][G[i][x]][0] - cnt[i+1][G[i+1][x]][0]) * w[1] + (cnt[i][G[i][x]][2] - cnt[i+1][G[i+1][x]][2]) * w[0];
cnt[i][G[i][x]][3] += w[0], to[i][G[i+1][x]][0] += w[0];
cnt[i][G[i][x]][4] += w[1], to[i][G[i+1][x]][1] += w[1];
}
else if (T[x] == 0) {
auto w = jump(i, x);
f += (cnt[i][G[i][x]][2] - cnt[i+1][G[i+1][x]][2]) * w + (cnt[i][G[i][x]][4] - to[i][G[i+1][x]][1]);
cnt[i][G[i][x]][3] += w-(T[G[i][x]] == 1), to[i][G[i+1][x]][0] += w-(T[G[i][x]] == 1);
}
else {
auto w = jump(i, x);
f += (cnt[i][G[i][x]][0] - cnt[i+1][G[i+1][x]][0]) * w + (cnt[i][G[i][x]][3] - to[i][G[i+1][x]][0]);
cnt[i][G[i][x]][4] += w-(T[G[i][x]] == 1), to[i][G[i+1][x]][1] += w-(T[G[i][x]] == 1);
}
}
Z[G[i][x]] += cnt[i+1][G[i+1][x]][0] * cnt[i+1][G[i+1][x]][2];
}
}
// Current number of tours; f is maintained by recalc() and change().
long long num_tours() {
    const long long answer = f;
    return answer;
}
/*
 * Judge status table pasted from the submission page (not part of the program):
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */