#include "joitour.h"
#include <bits/stdc++.h>
using i64 = long long;
#ifdef DEBUG
#include "/home/ahmetalp/Desktop/Workplace/debug.h"
#else
#define debug(...) void(23)
#endif
namespace {
// Number of vertices, as passed to init().
int N = 0;
// F[v]: colour of vertex v. DS::set distinguishes colours 0, 1 and "other" (2).
std::vector<int> F{};
// Edge list: edge i joins U[i] and V[i] (see init()).
std::vector<int> U{};
std::vector<int> V{};
// adj[v]: neighbours of v, built from the edge list in init().
std::vector<std::vector<int>> adj{};
// Fenwick (binary indexed) tree over positions [0, n).
//  * modify(p, x) / get(p)        : point update, prefix-sum query.
//  * modify(l, r, x) / get(p)     : range update on [l, r), point query at p.
//  * get(l, r)                    : sum over [l, r).
//  * oth(l, r)                    : point-update total outside [l, r).
struct fenwick {
  // Number of positions; valid indices are [0, n).
  int n = 0;
  // Sum of every x passed to modify(p, x), including updates at p >= n that
  // never reach the tree (range updates with r == n net out to zero here).
  int all = 0;
  // 1-based Fenwick array; tree[0] is unused.
  std::vector<int> tree;

  fenwick() = default;
  explicit fenwick(int n_) : n(n_), tree(n_ + 1) {}

  // Re-sizes to n_ positions and clears all stored values, including `all`.
  void init(int n_) {
    n = n_;
    all = 0;
    tree.assign(n + 1, 0);
  }

  // Adds x at position p. Positions p >= n only affect `all`.
  void modify(int p, int x) {
    all += x;
    for (p += 1; p <= n; p += p & -p) {
      tree[p] += x;
    }
  }

  // Adds x to every position in [l, r) as seen through the point query get(p).
  void modify(int l, int r, int x) {
    modify(l, +x);
    modify(r, -x);
  }

  // Prefix sum over [0, p]; p == -1 yields 0. Requires p < n.
  int get(int p) {
    int res = 0;
    for (p += 1; p; p -= p & -p) {
      res += tree[p];
    }
    return res;
  }

  // Sum over the half-open range [l, r). Requires l <= r <= n.
  int get(int l, int r) {
    return get(r - 1) - get(l - 1);
  }

  // Point-update total that lies outside [l, r).
  int oth(int l, int r) {
    return all - get(l, r);
  }
};
i64 ans = 0;
struct DS {
i64 cnt10 = 0;
i64 cnt12 = 0;
i64 cnt02 = 0;
std::map<int, i64> top_cnt12;
std::map<int, i64> top_cnt10;
fenwick fen0;
fenwick fen1;
fenwick fen2;
int n = 0;
int tim = 0;
int r = -1;
int col = -1;
std::map<int, int> tin;
std::map<int, int> tout;
std::map<int, int> top;
DS() {}
void init(int n_, int r_) {
n = n_;
fen0.init(n);
fen1.init(n);
fen2.init(n);
r = r_;
}
void open(int v) {
// assert(!tin.contains(v));
// tin[v] = tim++;
tim++;
}
void close(int v) {
// assert(!tout.contains(v));
// tout[v] = tim;
}
void set(int v, int x, int tin_v, int tout_v, int top_v, int tin_top_v, int tout_top_v) {
if (x < 0) {
// closes
x = -x - 1;
if (v == r) {
// root
if (x == 0) {
ans -= cnt12;
} else if (x == 1) {
ans -= cnt02;
} else {
ans -= cnt10;
}
col = -1;
} else {
if (x == 0) {
if (col == 2) {
ans -= fen1.get(tin_v);
}
if (col == 1) {
ans -= fen2.oth(tin_top_v, tout_top_v);
}
ans -= 1LL * fen1.get(tin_v) * fen2.oth(tin_top_v, tout_top_v);
ans -= (cnt12 - top_cnt12[top_v]);
cnt02 -= fen2.oth(tin_top_v, tout_top_v);
fen0.modify(tin_v, -1);
cnt10 -= fen1.get(tin_v);
top_cnt10[top_v] -= fen1.get(tin_v);
} else if (x == 1) {
if (col == 0) {
ans -= fen2.get(tin_v, tout_v);
}
if (col == 2) {
ans -= fen0.get(tin_v, tout_v);
}
ans -= 1LL * fen0.oth(tin_top_v, tout_top_v) * fen2.get(tin_v, tout_v);
ans -= 1LL * fen2.oth(tin_top_v, tout_top_v) * fen0.get(tin_v, tout_v);
fen1.modify(tin_v, tout_v, -1);
cnt10 -= fen0.get(tin_v, tout_v);
cnt12 -= fen2.get(tin_v, tout_v);
top_cnt10[top_v] -= fen0.get(tin_v, tout_v);
top_cnt12[top_v] -= fen2.get(tin_v, tout_v);
} else {
if (col == 0) {
ans -= fen1.get(tin_v);
}
if (col == 1) {
ans -= fen0.oth(tin_top_v, tout_top_v);
}
ans -= 1LL * fen1.get(tin_v) * fen0.oth(tin_top_v, tout_top_v);
ans -= (cnt10 - top_cnt10[top_v]);
cnt02 -= fen0.oth(tin_top_v, tout_top_v);
fen2.modify(tin_v, -1);
cnt12 -= fen1.get(tin_v);
top_cnt12[top_v] -= fen1.get(tin_v);
}
}
} else {
if (v == r) {
// root
if (x == 0) {
ans += cnt12;
} else if (x == 1) {
ans += cnt02;
} else {
ans += cnt10;
}
col = x;
} else {
if (x == 0) {
if (col == 2) {
ans += fen1.get(tin_v);
}
if (col == 1) {
ans += fen2.oth(tin_top_v, tout_top_v);
}
ans += 1LL * fen1.get(tin_v) * fen2.oth(tin_top_v, tout_top_v);
ans += (cnt12 - top_cnt12[top_v]);
cnt02 += fen2.oth(tin_top_v, tout_top_v);
fen0.modify(tin_v, +1);
cnt10 += fen1.get(tin_v);
top_cnt10[top_v] += fen1.get(tin_v);
} else if (x == 1) {
if (col == 0) {
ans += fen2.get(tin_v, tout_v);
}
if (col == 2) {
ans += fen0.get(tin_v, tout_v);
}
ans += 1LL * fen0.oth(tin_top_v, tout_top_v) * fen2.get(tin_v, tout_v);
ans += 1LL * fen2.oth(tin_top_v, tout_top_v) * fen0.get(tin_v, tout_v);
fen1.modify(tin_v, tout_v, +1);
cnt10 += fen0.get(tin_v, tout_v);
cnt12 += fen2.get(tin_v, tout_v);
top_cnt10[top_v] += fen0.get(tin_v, tout_v);
top_cnt12[top_v] += fen2.get(tin_v, tout_v);
} else {
if (col == 0) {
ans += fen1.get(tin_v);
}
if (col == 1) {
ans += fen0.oth(tin_top_v, tout_top_v);
}
ans += 1LL * fen1.get(tin_v) * fen0.oth(tin_top_v, tout_top_v);
ans += (cnt10 - top_cnt10[top_v]);
cnt02 += fen0.oth(tin_top_v, tout_top_v);
fen2.modify(tin_v, +1);
cnt12 += fen1.get(tin_v);
top_cnt12[top_v] += fen1.get(tin_v);
}
}
}
}
};
// centros[c]: counting structure for centroid c; dnq() initialises it only
// for vertices chosen as centroids.
std::vector<DS> centros{};
// act[v]: nonzero until v has been taken as a centroid by dnq().
std::vector<int> act{};
// siz[v]: subtree size of v from the most recent calc_sizes() call.
std::vector<int> siz{};
// pars[v]: one entry per centroid component containing v, pushed by insert():
// {centroid, tin_v, tout_v, top_v, tin_top_v, tout_top_v}.
std::vector<std::vector<std::array<int, 6>>> pars{};
// Fills siz[x] with the size of x's subtree for every active vertex reachable
// from v without crossing pr or an inactive vertex, treating v as the root.
// This is iterative (BFS order, then a reverse fold) so path-like components
// do not recurse O(component size) deep.
void calc_sizes(int v, int pr) {
  // (vertex, parent) pairs in BFS order: every parent precedes its children.
  std::vector<std::pair<int, int>> order;
  order.emplace_back(v, pr);
  for (std::size_t i = 0; i < order.size(); ++i) {
    // Copy the pair: emplace_back below may reallocate `order`.
    const auto [x, parent] = order[i];
    siz[x] = 1;
    for (const int u : adj[x]) {
      if (act[u] && u != parent) {
        order.emplace_back(u, x);
      }
    }
  }
  // Children are finished before their parents in reverse BFS order; index 0
  // is the root and has no parent to update.
  for (std::size_t i = order.size() - 1; i > 0; --i) {
    siz[order[i].second] += siz[order[i].first];
  }
}
// Walks from v (entered from pr) towards the first active neighbour whose
// subtree holds more than half of tot vertices, and returns the vertex where
// no such neighbour exists. Relies on siz[] from calc_sizes() rooted at the
// starting vertex. This is a loop rather than tail recursion, so the stack
// depth stays bounded.
int get_centroid(int v, int pr, int tot) {
  for (;;) {
    int heavy = -1;
    for (const int u : adj[v]) {
      if (act[u] && u != pr && siz[u] * 2 > tot) {
        heavy = u;
        break;
      }
    }
    if (heavy == -1) {
      return v;
    }
    pr = v;
    v = heavy;
  }
}
// Pre-order DFS of centroid r's component starting at v (entered from pr).
// For each visited vertex x it appends
// {r, tin, tin + siz[x], top, tin_top, tout_top} to pars[x], with tin taken
// from centros[r].tim, and advances the clock via open().
// Children of r start a new branch: top = the child itself, with its own
// interval. Deeper vertices inherit their parent's branch data.
// The traversal uses an explicit frame stack, in the same visit and
// open/close order as the recursive form, so deep components cannot
// overflow the call stack.
void insert(int r, int v, int pr, int top_v, int tin_top_v, int tout_top_v) {
  // A pending vertex: its parent, its branch data and the next neighbour index.
  struct Frame {
    int v;
    int pr;
    int top_v;
    int tin_top_v;
    int tout_top_v;
    std::size_t next;
  };
  std::vector<Frame> frames;

  auto enter = [&](int x, int parent, int top, int tin_top, int tout_top) {
    const int tim = centros[r].tim;
    debug(r, x, tim, tim + siz[x], top, tin_top, tout_top);
    pars[x].push_back({r, tim, tim + siz[x], top, tin_top, tout_top});
    centros[r].open(x);
    frames.push_back({x, parent, top, tin_top, tout_top, 0});
  };

  enter(v, pr, top_v, tin_top_v, tout_top_v);
  while (!frames.empty()) {
    // Copy: enter() may reallocate `frames`.
    const Frame cur = frames.back();
    if (cur.next == adj[cur.v].size()) {
      centros[r].close(cur.v);
      frames.pop_back();
      continue;
    }
    ++frames.back().next;
    const int u = adj[cur.v][cur.next];
    if (!act[u] || u == cur.pr) {
      continue;
    }
    if (cur.v == r) {
      // u will receive the current clock value as its tin.
      const int tin_u = centros[r].tim;
      enter(u, cur.v, u, tin_u, tin_u + siz[u]);
    } else {
      enter(u, cur.v, cur.top_v, cur.tin_top_v, cur.tout_top_v);
    }
  }
}
// Centroid decomposition of the active component containing r.
// It finds that component's centroid, records every member's Euler data
// under it, deactivates it, then recurses into each remaining neighbour
// component.
void dnq(int r) {
  calc_sizes(r, r);
  const int component_size = siz[r];
  const int c = get_centroid(r, r, component_size);
  // Re-root the sizes at the centroid so insert() sees subtree sizes from c.
  calc_sizes(c, c);
  centros[c].init(component_size, c);
  insert(c, c, -1, -1, -1, -1);
  act[c] = 0;
  for (const int u : adj[c]) {
    if (act[u]) {
      dnq(u);
    }
  }
}
};
void init(int N, std::vector<int> F, std::vector<int> U, std::vector<int> V,
int Q) {
::N = N;
::F = F;
::U = U;
::V = V;
adj.assign(N, {});
for (int i = 0; i < N - 1; ++i) {
adj[U[i]].emplace_back(V[i]);
adj[V[i]].emplace_back(U[i]);
}
centros.assign(N, {});
act.assign(N, 1);
siz.assign(N, 1);
pars.assign(N, {});
dnq(0);
for (int v = 0; v < N; ++v) {
for (auto[r, tin_v, tout_v, top_v, tin_top_v, tout_top_v] : pars[v]) {
centros[r].set(v, F[v], tin_v, tout_v, top_v, tin_top_v, tout_top_v);
}
}
}
// Recolours vertex X to Y. It first removes X's old colour from every
// centroid structure containing X, then inserts the new one.
void change(int X, int Y) {
  // Sends `encoded` to each centroid structure that contains X; negative
  // values (-c - 1) tell DS::set to remove colour c.
  const auto broadcast = [X](int encoded) {
    for (const auto& [r, tin_v, tout_v, top_v, tin_top_v, tout_top_v] :
         pars[X]) {
      centros[r].set(X, encoded, tin_v, tout_v, top_v, tin_top_v, tout_top_v);
    }
  };
  broadcast(-F[X] - 1);
  F[X] = Y;
  broadcast(Y);
}
// Current number of tours, as accumulated by DS::set.
i64 num_tours() {
  const i64 total = ans;
  return total;
}