#ifndef LOCAL
#pragma GCC Optimize("O3,Ofast,unroll-loops")
#pragma GCC Target("bmi,bmi2,avx,avx2")
#endif
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using ld = long double;
// Competitive-programming shorthands. NOTE: `f` and `s` are object-like
// macros, so every identifier spelled `f` or `s` later in this file is really
// `first` / `second` (e.g. the global vector declared as `f`).
#define f first
#define s second
#define mp make_pair
#define pb push_back
#define pii pair<int, int>
#define all(x) (x).begin(), (x).end()
// Reverse range. The comma between the two iterators was missing, which made
// every use of rall() a syntax error.
#define rall(x) (x).rbegin(), (x).rend()
#ifndef LOCAL
// Avoid the flush performed by std::endl on the judge.
#define endl "\n"
#endif
// Pseudo-random generator with a fixed seed (reproducible runs).
mt19937 rnd(11);
// Number of levels for binary lifting and centroid decomposition; presumably
// sized for n up to about 2^18 — confirm against the problem limits.
const int LOG_N = 18;
// Fenwick (binary indexed) tree over 0-based positions holding int counters.
// The caller sizes `t` directly before use.
struct F {
    vector<int> t;
    F() = default;

    // Sum of the values stored at positions 0..r inclusive; 0 when r < 0.
    int get(int r) {
        int total = 0;
        while (r >= 0) {
            total += t[r];
            r = (r & (r + 1)) - 1;  // step to the end of the previous covered range
        }
        return total;
    }

    // Adds x at position i; does nothing when i is outside [0, t.size()).
    void upd(int i, int x) {
        while (i >= 0 && static_cast<size_t>(i) < t.size()) {
            t[i] += x;
            i |= i + 1;  // next node whose range covers position i
        }
    }
};
// Fenwick (binary indexed) tree over 0-based positions holding 64-bit sums.
// The caller sizes `t` directly before use. (`long long` is the same type as
// the file's `ll` alias.)
struct FL {
    vector<long long> t;
    FL() = default;

    // Sum of the values stored at positions 0..r inclusive; 0 when r < 0.
    long long get(int r) {
        long long total = 0;
        for (; r >= 0; r = (r & (r + 1)) - 1) {
            total += t[r];
        }
        return total;
    }

    // Adds x at position i; does nothing when i is outside [0, t.size()).
    // The delta is 64-bit to match the stored sums: callers pass long long
    // values (results of sumup()/sumsub()), which an `int` parameter would
    // silently truncate.
    void upd(int i, long long x) {
        for (; i >= 0 && static_cast<size_t>(i) < t.size(); i |= i + 1) {
            t[i] += x;
        }
    }
};
// ---- Global state shared by init(), change() and the helpers ----

// Adjacency lists of the input tree.
vector<vector<int>> graph;
// Centroid of the component containing each vertex, per decomposition level
// (indexed [level][vertex]).
vector<int> CP[LOG_N];
// Level at which the vertex was chosen as a centroid (-1 until chosen).
vector<int> lvl;
// Scratch subtree sizes used while decomposing.
vector<int> sz;
// Vertex values. NOTE: `f` is the `first` macro, so this is really `first`.
vector<int> f;
// Euler-tour stamps of the tree rooted at 0; tout is one past the last stamp
// used inside the vertex's subtree.
vector<int> tin;
vector<int> tout;
// binup[l][x]: the 2^l-th ancestor of x (root 0 is its own parent).
vector<int> binup[LOG_N];
// Entry / exit stamps of the vertex in the DFS from its level's centroid.
vector<int> tin1[LOG_N];
vector<int> tout1[LOG_N];
// The centroid's neighbour whose branch holds the vertex (centroid -> itself).
vector<int> fst[LOG_N];
// Whether the vertex is currently inserted (set by add(), cleared by del()).
vector<int> act;
// Per-centroid aggregates maintained by add()/del() and read by calc().
vector<ll> zerotwo;
// T[c][centroid] counts inserted value-c vertices keyed by tin1;
// T[c + 3][centroid] keys the same vertices by tout1.
vector<F> T[6];
vector<FL> two;
vector<FL> zero;
// Running answer returned by num_tours().
ll ans = 0;
// Timestamp counter for dfs().
int Tm = 0;
// Recursive DFS from v: records v's parent in binup[0] and its entry/exit
// stamps. tin[v] is taken on entry; tout[v] is one past the last stamp used
// in v's subtree. A vertex counts as unvisited while tin is -1.
void dfs(int v, int prev = 0) {
    tin[v] = Tm;
    ++Tm;
    binup[0][v] = prev;
    for (size_t k = 0; k < graph[v].size(); ++k) {
        const int to = graph[v][k];
        if (tin[to] != -1) {
            continue;  // already visited (the parent, in a tree)
        }
        dfs(to, v);
    }
    tout[v] = Tm;
}
// True iff a is an ancestor of b (or a == b) in the tree rooted at 0, using
// the Euler-tour stamps [tin, tout) filled in by dfs().
bool inside(int a, int b) {
    // The exit check previously compared tout[b] with itself (always true),
    // so every vertex entered before b was reported as its ancestor.
    return tin[a] <= tin[b] && tout[b] <= tout[a];
}
// Lowest common ancestor of a and b by binary lifting: while a is not an
// ancestor of b, lift a as high as possible without becoming one, then take
// its parent.
int lca(int a, int b) {
    if (!inside(a, b)) {
        for (int step = LOG_N; step-- > 0;) {
            const int up = binup[step][a];
            if (!inside(up, b)) {
                a = up;
            }
        }
        a = binup[0][a];
    }
    return a;
}
// Number of vertices on the path from a up to the root 0 (both included)
// whose value f[] equals t. Walks parent links one vertex at a time.
int calcUp(int a, int t) {
    int cnt = 0;
    for (int cur = a;; cur = binup[0][cur]) {
        if (f[cur] == t) {
            ++cnt;
        }
        if (cur == 0) {
            break;
        }
    }
    return cnt;
}
int calcWay(int a, int b, int t) {
int lc = lca(a, b);
int ans = calcUp(a, t) + calcUp(b, t) - calcUp(lc, t) * 2;
if (f[lc] == t) {
++ans;
}
return ans;
}
// Number of inserted value-t vertices that are strict ancestors of v in the
// DFS from v's level-lg centroid (the centroid included when v is not it).
// "Entered by tout1[v]" minus "exited by tout1[v]" leaves exactly the vertices
// whose stamp interval encloses v's.
ll sumup(int v, int lg, int t) {
    const int centre = CP[lg][v];
    const int stamp = tout1[lg][v];
    const int entered = T[t][centre].get(stamp);
    const int exited = T[3 + t][centre].get(stamp);
    return entered - exited;
}
// Number of inserted value-t vertices in v's subtree (v included) of the DFS
// from its level-lg centroid, i.e. those with tin1 in [tin1[v], tout1[v]).
ll sumsub(int v, int lg, int t) {
    F &tree = T[t][CP[lg][v]];
    const int upToEnd = tree.get(tout1[lg][v] - 1);
    const int beforeStart = tree.get(tin1[lg][v] - 1);
    return upToEnd - beforeStart;
}
// Number of counted triples that contain v, evaluated against the current
// Fenwick / aggregate state. v is expected to be inserted already: add()
// calls this after inserting v, and del() calls it before removing v.
// NOTE(review): the branches suggest the answer counts triples (x, y, z) with
// values 0, 1, 2 where y lies on the tree path x..z — confirm against the
// problem statement.
ll calc(int v) {
    // Local accumulator; shadows the global `ans`.
    ll ans = 0;
    // Levels above v's own: v lies in the component of centroid CP[lg][v],
    // inside the branch fst[lg][v]. Here the triple's path crosses that
    // centroid, with at least one partner vertex outside v's branch.
    for (int lg = lvl[v] - 1; lg >= 0; --lg) {
        if (f[v] == 0) {
            // Value-1 strict ancestors of v (the centroid included) ...
            ll onenum = sumup(v, lg, 1);
            // ... times value-2 vertices of the component outside v's branch.
            ll twonum = sumsub(CP[lg][v], lg, 2) - sumsub(fst[lg][v], lg, 2);
            ans += onenum * twonum;
            // Plus `two` summed over the whole component minus v's branch range.
            ans += two[CP[lg][v]].get(tout1[lg][CP[lg][v]] - 1);
            ans -= two[CP[lg][v]].get(tout1[lg][fst[lg][v]] - 1) - two[CP[lg][v]].get(tin1[lg][fst[lg][v]] - 1);
        } else if (f[v] == 1) {
            // Value-0 vertices below v paired with value-2 vertices outside
            // v's branch, and value-2 below v paired with value-0 outside.
            ans += sumsub(v, lg, 0) * (sumsub(CP[lg][v], lg, 2) - sumsub(fst[lg][v], lg, 2));
            ans += sumsub(v, lg, 2) * (sumsub(CP[lg][v], lg, 0) - sumsub(fst[lg][v], lg, 0));
        } else {
            // Mirror of the value-0 case with the roles of 0 and 2 swapped.
            ll onenum = sumup(v, lg, 1);
            ll zeronum = sumsub(CP[lg][v], lg, 0) - sumsub(fst[lg][v], lg, 0);
            ans += onenum * zeronum;
            ans += zero[CP[lg][v]].get(tout1[lg][CP[lg][v]] - 1);
            ans -= zero[CP[lg][v]].get(tout1[lg][fst[lg][v]] - 1) - zero[CP[lg][v]].get(tin1[lg][fst[lg][v]] - 1);
        }
    }
    // v's own level: v is the centroid, so read its aggregates over the whole
    // component (full Fenwick range, or the zerotwo counter for value 1).
    if (f[v] == 0) {
        ans += two[v].get(tout1[lvl[v]][v] - 1);
    } else if (f[v] == 2) {
        ans += zero[v].get(tout1[lvl[v]][v] - 1);
    } else {
        ans += zerotwo[v];
    }
    return ans;
}
// Removes vertex v (the inverse of add()). It first subtracts v's triples
// from the global answer while v is still inserted. Then, at every centroid
// level containing v (lvl[v] down to 0), it takes v out of the value Fenwicks
// T and undoes v's contribution to that centroid's zero / two / zerotwo
// aggregates. Finally it marks v inactive.
// NOTE(review): the amount removed is the count at removal time, not the value
// stored at insertion. Presumably only range sums over whole branches (which
// is all calc() reads) stay meaningful, not individual Fenwick entries.
void del(int v) {
    ans -= calc(v);
    for (int lg = lvl[v]; lg >= 0; --lg) {
        // T[c] is keyed by entry stamp, T[c + 3] by exit stamp.
        T[f[v]][CP[lg][v]].upd(tin1[lg][v], -1);
        T[f[v] + 3][CP[lg][v]].upd(tout1[lg][v], -1);
        // At its own level v is the centroid; the aggregates are only updated
        // for non-centroid vertices.
        if (lg != lvl[v]) {
            if (f[v] == 0) {
                // Value-2 vertices outside v's branch; the `&& act[...]` term
                // drops the centroid itself when it is an inserted value-2 vertex.
                zerotwo[CP[lg][v]] -= sumsub(CP[lg][v], lg, 2) - sumsub(fst[lg][v], lg, 2) - (f[CP[lg][v]] == 2 && act[CP[lg][v]]);
                // Value-1 strict ancestors of v, excluding the centroid.
                zero[CP[lg][v]].upd(tin1[lg][v], -(sumup(v, lg, 1) - (f[CP[lg][v]] == 1 && act[CP[lg][v]])));
            } else if (f[v] == 2) {
                // Mirror of the value-0 case.
                zerotwo[CP[lg][v]] -= sumsub(CP[lg][v], lg, 0) - sumsub(fst[lg][v], lg, 0) - (f[CP[lg][v]] == 0 && act[CP[lg][v]]);
                two[CP[lg][v]].upd(tin1[lg][v], -(sumup(v, lg, 1) - (f[CP[lg][v]] == 1 && act[CP[lg][v]])));
            } else {
                // Value-0 / value-2 vertices in v's subtree.
                zero[CP[lg][v]].upd(tin1[lg][v], -sumsub(v, lg, 0));
                two[CP[lg][v]].upd(tin1[lg][v], -sumsub(v, lg, 2));
            }
        }
    }
    act[v] = false;
}
// Inserts vertex v. At every centroid level containing v (lvl[v] down to 0),
// it puts v into the value Fenwicks T and adds v's contribution to that
// centroid's zero / two / zerotwo aggregates, counting only vertices already
// inserted. Then it marks v active and adds v's triples to the global answer.
void add(int v) {
    for (int lg = lvl[v]; lg >= 0; --lg) {
        // T[c] is keyed by entry stamp, T[c + 3] by exit stamp.
        T[f[v]][CP[lg][v]].upd(tin1[lg][v], 1);
        T[f[v] + 3][CP[lg][v]].upd(tout1[lg][v], 1);
        // At its own level v is the centroid; the aggregates are only updated
        // for non-centroid vertices.
        if (lg != lvl[v]) {
            if (f[v] == 0) {
                // Value-2 vertices outside v's branch; the `&& act[...]` term
                // drops the centroid itself when it is an inserted value-2 vertex.
                zerotwo[CP[lg][v]] += sumsub(CP[lg][v], lg, 2) - sumsub(fst[lg][v], lg, 2) - (f[CP[lg][v]] == 2 && act[CP[lg][v]]);
                // Value-1 strict ancestors of v, excluding the centroid.
                zero[CP[lg][v]].upd(tin1[lg][v], (sumup(v, lg, 1) - (f[CP[lg][v]] == 1 && act[CP[lg][v]])));
            } else if (f[v] == 2) {
                // Mirror of the value-0 case.
                zerotwo[CP[lg][v]] += sumsub(CP[lg][v], lg, 0) - sumsub(fst[lg][v], lg, 0) - (f[CP[lg][v]] == 0 && act[CP[lg][v]]);
                two[CP[lg][v]].upd(tin1[lg][v], (sumup(v, lg, 1) - (f[CP[lg][v]] == 1 && act[CP[lg][v]])));
            } else {
                // Value-0 / value-2 vertices in v's subtree.
                zero[CP[lg][v]].upd(tin1[lg][v], sumsub(v, lg, 0));
                two[CP[lg][v]].upd(tin1[lg][v], sumsub(v, lg, 2));
            }
        }
    }
    act[v] = true;
    ans += calc(v);
}
// Builds every structure for a tree with n vertices, vertex values F, and
// edges (u[i], v[i]) for i < n - 1. It then inserts all vertices, so that the
// global `ans` holds the initial answer. `q` (presumably the query count) is
// not used here.
void init(int n, vector<int> F, vector<int> u, vector<int> v, int q) {
    // `f` is the `first` macro: this assigns the global value vector. Inside
    // this function the parameter F shadows struct F.
    f = F;
    zerotwo.resize(n);
    act.resize(n);
    graph.resize(n);
    lvl.assign(n, -1);
    binup[0].resize(n);
    tin.assign(n, -1);
    tout.resize(n);
    for (int i = 0; i < 6; ++i) {
        T[i].resize(n);
    }
    two.resize(n);
    zero.resize(n);
    for (int i = 0; i < n - 1; ++i) {
        graph[u[i]].pb(v[i]);
        graph[v[i]].pb(u[i]);
    }
    // Root the tree at 0: parents and Euler-tour stamps used by inside()/lca().
    dfs(0);
    // binup[l][i] = 2^l-th ancestor of i (the root is its own parent).
    for (int l = 1; l < LOG_N; ++l) {
        binup[l].resize(n);
        for (int i = 0; i < n; ++i) {
            binup[l][i] = binup[l - 1][binup[l - 1][i]];
        }
    }
    // Centroid decomposition, one level per value of lg (captured by the
    // lambdas below). lvl[x] stays -1 until x is chosen as a centroid, so
    // `lvl[u] == -1` restricts every walk to the current component.
    int lg = 0;
    // Subtree sizes of the current component, rooted at v.
    auto calcSz = [&](int v, int prev, auto &&self) -> void {
        sz[v] = 1;
        for (auto &u : graph[v]) {
            if (lvl[u] == -1 && u != prev) {
                self(u, v, self);
                sz[v] += sz[u];
            }
        }
    };
    // Moves toward any child holding more than half of the component (Tsz).
    auto findCenter = [&](int v, int prev, int Tsz, auto &&self) -> int {
        for (auto &u : graph[v]) {
            if (lvl[u] == -1 && u != prev && sz[u] * 2 > Tsz) {
                return self(u, v, Tsz, self);
            }
        }
        return v;
    };
    // Stamp counter for paint(); reset to 0 before each component.
    int Tm1;
    // DFS from the centroid recording, for this level: CP (the centroid),
    // fst (the centroid's neighbour on the way down; the centroid maps to
    // itself), and tin1/tout1 (two stamps per vertex).
    auto paint = [&](int v, int prev, int center, auto &&self) -> void {
        if (prev == center) {
            fst[lg][v] = v;
        } else {
            fst[lg][v] = fst[lg][prev];
        }
        tin1[lg][v] = Tm1++;
        CP[lg][v] = center;
        for (auto &u : graph[v]) {
            if (lvl[u] == -1 && u != prev) {
                self(u, v, center, self);
            }
        }
        tout1[lg][v] = Tm1++;
    };
    for (; lg < LOG_N; ++lg) {
        CP[lg].resize(n);
        fst[lg].resize(n);
        tin1[lg].resize(n);
        tout1[lg].resize(n);
        // -1 marks "not reached at this level yet".
        sz.assign(n, -1);
        for (int i = 0; i < n; ++i) {
            if (lvl[i] == -1 && sz[i] == -1) {
                Tm1 = 0;
                calcSz(i, i, calcSz);
                int center = findCenter(i, i, sz[i], findCenter);
                lvl[center] = lg;
                // Re-root the sizes at the centroid: sz[center] = component size.
                calcSz(center, center, calcSz);
                // Two stamps per vertex, hence 2 * component size positions.
                for (int j = 0; j < 6; ++j) {
                    T[j][center].t.resize(sz[center] * 2);
                }
                two[center].t.resize(sz[center] * 2);
                zero[center].t.resize(sz[center] * 2);
                paint(center, center, center, paint);
            }
        }
    }
    // Insert every vertex; add() accumulates the initial answer in `ans`.
    for (int i = 0; i < n; ++i) {
        add(i);
    }
}
// Sets vertex v's value to x and keeps the answer up to date: v is removed
// under its old value and re-inserted under the new one.
void change(int v, int x) {
    vector<int> &colour = f;  // the global value array
    del(v);
    colour[v] = x;
    add(v);
}
// Current answer, maintained incrementally by init() and change().
long long num_tours() {
    const long long current = ans;
    return current;
}
#ifdef LOCAL
#include <cassert>
#include <cstdio>
// Local driver, compiled only with -DLOCAL. It reads from in.txt: N, N vertex
// values, N - 1 edges "u v", Q, then Q queries "vertex value". It writes the
// answer after init and after every change to out.txt.
int main() {
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
    // Every scanf runs outside assert(). With -DNDEBUG the assert expression
    // is discarded entirely, which previously skipped reading the input.
    int got;
    int N;
    got = scanf("%d", &N);
    assert(got == 1);
    std::vector<int> F(N);
    for (int i = 0; i < N; i++) {
        got = scanf("%d", &F[i]);
        assert(got == 1);
    }
    std::vector<int> U(N - 1), V(N - 1);
    for (int j = 0; j < N - 1; j++) {
        got = scanf("%d %d", &U[j], &V[j]);
        assert(got == 2);
    }
    int Q;
    got = scanf("%d", &Q);
    assert(got == 1);
    init(N, F, U, V, Q);
    printf("%lld\n", num_tours());
    fflush(stdout);
    for (int k = 0; k < Q; k++) {
        int X, Y;
        got = scanf("%d %d", &X, &Y);
        assert(got == 2);
        change(X, Y);
        printf("%lld\n", num_tours());
        fflush(stdout);
    }
    (void)got;  // only read by assert(); keeps -DNDEBUG builds warning-free
}
#endif
/*
 * Residue scraped from the judge's submission page (the result tables never
 * loaded). It is kept as a comment so the file still compiles. The table below
 * appeared 7 times:
 *
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|
 * | Fetching results... |
 */