#include <bits/stdc++.h>
// Shorthands: pair member access and a whole-container iterator range.
#define fi first
#define se second
#define all(v) v.begin(), v.end()
using namespace std;
using pii = pair<int, int>;
using ll = long long;
// Highest binary-lifting level: up[0..K] allows jumps of up to 2^(K+1)-1 levels (>= N).
constexpr int K = 17;
// Capacity of the per-node arrays; node ids are 1..n with n < N
// (presumably n <= 1e5 from the problem limits — confirm).
constexpr int N = 1e5+5;
int n;               // number of nodes; node 1 is treated as the root (see main)
int c[N];            // node colors, rewritten to ranks 1..#distinct by compress()
pii edges[N];        // edges[i] = {parent, child} in input order, i in [1, n-1]
int par[N];          // parent of each node (stays 0 when never assigned, e.g. the root)
vector<int> g[N];    // child lists
int up[K+1][N];      // up[i][v] = 2^i-th ancestor of v, 0 once past the root
int sz[N], heavy[N]; // subtree size; heavy child (0 for leaves)
void dfs(int v) {
sz[v] = 1;
for (int u : g[v]) {
up[0][u] = v;
for (int i = 1; i <= K; ++i) {
up[i][u] = up[i-1][up[i-1][u]];
}
dfs(u);
if (!heavy[v] || sz[u] > sz[heavy[v]]) {
heavy[v] = u;
}
sz[v] += sz[u];
}
}
int head[N], num[N], timer;
// Heavy-light decomposition. It assigns preorder positions num[] starting at
// timer+1 and records the top node head[] of each node's chain.
// The heavy child is visited first, so every chain occupies the contiguous
// position range num[head[x]]..num[x] for each x on it.
// h is the head of the chain that v belongs to.
void hld(int v, int h) {
    head[v] = h;
    num[v] = ++timer;
    const int hv = heavy[v];
    if (hv != 0) {
        hld(hv, h);
    }
    for (const int child : g[v]) {
        if (child == hv) continue;
        hld(child, child);  // every light child starts a new chain
    }
}
namespace ds {
// Segment tree over positions [l, r] (called with [1, n]) that supports
// range-assign and point query. lazy[id] is the value assigned to node id and
// active[id] says whether that assignment is present. A query pushes the
// assignments down as it descends.
int lazy[N*4];
signed active[N*4];
void init() {
    std::fill(std::begin(lazy), std::end(lazy), 0);
    std::fill(std::begin(active), std::end(active), 0);
}
// Pushes a present assignment of internal node id (covering [l, r]) to both
// children and clears it on id. Leaves keep theirs.
void fix(int id, int l, int r) {
    if (l == r || !active[id]) return;
    const int left = id * 2;
    const int right = left + 1;
    lazy[left] = lazy[id];
    lazy[right] = lazy[id];
    active[left] = 1;
    active[right] = 1;
    active[id] = 0;
}
// Assigns val to every position in [u, v] within node id's range [l, r].
void update(int id, int l, int r, int u, int v, int val) {
    fix(id, l, r);
    const bool disjoint = v < l || r < u;
    if (disjoint) return;
    if (l >= u && v >= r) {
        lazy[id] = val;
        active[id] = 1;
        return;
    }
    const int mid = (l + r) / 2;
    update(2 * id, l, mid, u, v, val);
    update(2 * id + 1, mid + 1, r, u, v, val);
}
// Returns the value currently assigned at position pos. Every position must
// have been assigned at least once beforehand.
int get(int id, int l, int r, int pos) {
    while (true) {
        fix(id, l, r);
        if (active[id]) return lazy[id];
        const int mid = (l + r) / 2;
        if (pos <= mid) {
            id = 2 * id;
            r = mid;
        } else {
            id = 2 * id + 1;
            l = mid + 1;
        }
    }
}
}
// Assigns val to every node on the path from the root to v. It does this one
// heavy chain at a time, each chain being a contiguous range of positions.
void update(int v, int val) {
    while (v != 0) {
        const int top = head[v];
        ds::update(1, 1, n, num[top], num[v], val);
        v = par[top];
    }
}
// Value stored for node v, or -1 for the sentinel node 0.
int get(int v) {
    return v == 0 ? -1 : ds::get(1, 1, n, num[v]);
}
// Starting at v, climbs through the ancestors that hold the same stored value
// as v. Returns {parent of the topmost such node (0 past the root), number of
// nodes in that run, v included}.
// Binary lifting is valid only if each stored value occupies one contiguous
// stretch of the root path. Otherwise a jump can skip nodes with other values.
pii upchain(int v) {
    const int run = get(v);
    int climbed = 0;
    for (int lvl = K; lvl >= 0; --lvl) {
        const int anc = up[lvl][v];
        if (anc == 0 || get(anc) != run) continue;
        climbed += 1 << lvl;
        v = anc;
    }
    return {up[0][v], climbed + 1};
}
namespace fen {
// Fenwick (binary indexed) tree of long long sums over indices 1..n.
ll bit[N];
void init() {
    std::fill(std::begin(bit), std::end(bit), 0LL);
}
// Adds x at index p. p must be in [1, n]; p == 0 would never advance.
void add(int p, ll x) {
    while (p <= n) {
        bit[p] += x;
        p += p & -p;
    }
}
// Prefix sum over indices 1..p (0 for p <= 0).
ll get(int p) {
    ll sum = 0;
    while (p > 0) {
        sum += bit[p];
        p -= p & -p;
    }
    return sum;
}
// Sum over the index range [l, r].
ll get(int l, int r) {
    return get(r) - get(l - 1);
}
}
void solve() {
for (int i = 1; i <= n; ++i) {
ds::update(1, 1, n, num[i], num[i], c[i]);
}
for (int i = 1; i < n; ++i) {
auto [a, b] = edges[i];
vector<pii> path;
int u;
for (int v = a; v; v = u) {
pii g = upchain(v);
path.emplace_back(get(v), g.se);
u = g.fi;
}
ll res = 0;
for (auto [val, cnt] : path) {
res += cnt * fen::get(val - 1);
fen::add(val, cnt);
}
cout << res << "\n";
for (auto [val, cnt] : path) {
fen::add(val, -cnt);
}
update(a, c[b]);
}
}
void compress() {
vector<int> comp;
for (int i = 1; i <= n; ++i) comp.push_back(c[i]);
sort(all(comp));
comp.resize(unique(all(comp)) - comp.begin());
for (int i = 1; i <= n; ++i) {
c[i] = lower_bound(all(comp), c[i]) - comp.begin() + 1;
}
}
signed main() {
cin.tie(0)->sync_with_stdio(0);
ds::init();
fen::init();
cin >> n;
for (int i = 1; i <= n; ++i) cin >> c[i];
compress();
for (int i = 1; i < n; ++i) {
int u, v; cin >> u >> v;
edges[i] = {u, v};
par[v] = u;
g[u].push_back(v);
}
dfs(1);
hld(1, 1);
solve();
}
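/*
Judge result tables captured together with the source (not part of the program):
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... |
*/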