#include <bits/stdc++.h>
using namespace std;
// Capacity for 1-based vertex arrays (100000 vertices plus slack).
constexpr int N = 100'010;

int n;                    // number of vertices, read in main()
int c[N];                 // per-vertex colour (replaced by its compressed rank in main())
vector<int> ad[N];        // tree adjacency lists
pair<int, int> edge[N];   // edge[i] = i-th input edge {u, v}, 1-based
int sz[N];                // sz[u] = number of strict descendants of u, filled by dfs()
// Accumulates into sz[u] the number of strict descendants of u, treating p as
// u's parent (p = -1 for the root). sz[] is added to, not reset, so it must be
// zero before the first call.
// NOTE(review): recursion depth equals the tree height (up to n on a path).
void dfs(int u, int p = -1) {
    for (const int child : ad[u]) {
        if (child != p) {
            dfs(child, u);
            sz[u] += 1 + sz[child];
        }
    }
}
// Heavy-light decomposition state, indexed by vertex unless noted:
//   par[u]    parent of u (main() makes the root its own parent)
//   head[u]   topmost vertex of the heavy chain that contains u
//   st[u]     1-based pre-order position of u; node[pos] is the inverse map
//   ed[u]     largest pre-order position inside u's subtree
//   num       running pre-order counter
int par[N];
int head[N];
int st[N];
int node[N];
int ed[N];
int num;

// Decomposes the subtree of u (p = parent, -1 for the root) into heavy chains.
// Each adjacency list is sorted in place by descending sz, and the first
// non-parent child (the heavy child) inherits u's chain head and is visited
// first. That gives every chain a contiguous run of st positions starting at
// st[head]. Requires dfs() to have filled sz[] beforehand.
// NOTE(review): recursion depth equals the tree height.
void hld(int u, int p = -1) {
    num += 1;
    st[u] = num;
    node[num] = u;
    if (head[u] == 0) head[u] = u;

    sort(ad[u].begin(), ad[u].end(), [](int a, int b) { return sz[b] < sz[a]; });

    bool heavyAssigned = false;
    for (const int child : ad[u]) {
        if (child == p) continue;
        if (!heavyAssigned) {
            heavyAssigned = true;
            head[child] = head[u];  // heavy child continues u's chain
        }
        par[child] = u;
        hld(child, u);
    }
    ed[u] = num;
}
// True when u is an ancestor of v (or u == v): v's pre-order interval
// [st[v], ed[v]] is nested inside u's interval [st[u], ed[u]].
inline bool anc(int u, int v) {
    const bool startsInside = st[u] <= st[v];
    const bool endsInside = ed[v] <= ed[u];
    return startsInside && endsInside;
}
// Splits the tree path between u and v into intervals {lo, hi} of st positions,
// each lying on a single heavy chain. Intervals from climbing u's side come
// first, then those from v's side, then the final interval on the shared chain.
// With v == 1 (the root), every interval's lo is the st position of its chain
// head.
vector<pair<int, int>> get(int u, int v) {
    vector<pair<int, int>> segments;
    while (!anc(head[u], head[v])) {
        segments.emplace_back(st[head[u]], st[u]);
        u = par[head[u]];
    }
    while (head[v] != head[u]) {
        segments.emplace_back(st[head[v]], st[v]);
        v = par[head[v]];
    }
    const int a = st[u];
    const int b = st[v];
    segments.emplace_back(min(a, b), max(a, b));
    return segments;
}
// Fenwick tree over indices 1..n with an undo log, so it can be reset in time
// proportional to the updates made rather than to n.
namespace BIT {
// Every positive update (index, amount), replayed negatively by clear().
vector<pair<int, int>> save;
int bit[N];

// Adds x at index i; positive amounts are logged in `save`.
inline void upd(int i, int x) {
    if (x > 0) save.emplace_back(i, x);
    while (i <= n) {
        bit[i] += x;
        i += i & (-i);
    }
}

// Prefix sum over indices 1..i (0 for i == 0).
inline int que(int i) {
    int total = 0;
    while (i != 0) {
        total += bit[i];
        i -= i & (-i);
    }
    return total;
}

// Undoes every logged update and empties the log.
inline void clear() {
    while (!save.empty()) {
        const auto [pos, amount] = save.back();
        save.pop_back();
        upd(pos, -amount);  // negative amount: not re-logged
    }
}
}
// vt[h] describes the coloured prefix of the heavy chain whose head has st
// position h. It is a stack of runs {count, colour}: back() is the run nearest
// the chain head, and deeper runs sit further down the stack.
vector<pair<int, int>> vt[N];

// Walks the path from u up to the root (vertex 1) and does two things:
//  1. Returns the number of pairs (a, d) of coloured path positions where a is
//     an ancestor of d and colour(a) > colour(d). This is the inversion count of
//     the colour sequence read root-first. Only positions already recorded in
//     vt contribute.
//  2. Recolours every position of the path to w.
// Runs are fed to the Fenwick tree deepest-first, so when a run is inserted the
// tree holds exactly the deeper positions. Colours are expected in 1..n
// (the Fenwick tree range).
long long cal(int u, int w) {
    long long ret = 0;
    for (const auto& [l, r] : get(u, 1)) {
        // Path positions on this chain: from the head (position l) down to r.
        const int need = r - l + 1;
        // Runs covering positions l..r, collected head-first.
        vector<pair<int, int>> proc;
        int length = 0;
        while (!vt[l].empty()) {
            const auto [cnt, value] = vt[l].back();
            vt[l].pop_back();
            length += cnt;
            if (length >= need) {
                // This run reaches or crosses position r. Its on-path part is
                // always counted, including the exact fit length == need. Any
                // part below r stays on the stack with its old colour.
                const int below = length - need;
                if (below > 0) vt[l].push_back({below, value});
                proc.push_back({cnt - below, value});
                break;
            }
            proc.push_back({cnt, value});
        }
        // Positions l..r now all carry colour w.
        vt[l].push_back({need, w});

        // Deepest run first, so the BIT holds exactly the deeper positions.
        reverse(proc.begin(), proc.end());
        for (const auto& [cnt, value] : proc) {
            // Each of the cnt positions in this run is an ancestor of every
            // deeper position with a smaller colour: cnt * que() inversions.
            ret += 1LL * cnt * BIT::que(value - 1);
            BIT::upd(value, cnt);
        }
    }
    BIT::clear();
    return ret;
}
int32_t main() {
cin.tie(0)->sync_with_stdio(0);
cin >> n;
for (int i = 1; i <= n; ++i) cin >> c[i];
for (int i = 1; i < n; ++i) {
int u, v; cin >> u >> v;
ad[u].push_back(v);
ad[v].push_back(u);
edge[i] = {u, v};
}
{ // discrete
vector<int> rrh(c + 1, c + n + 1);
sort(rrh.begin(), rrh.end());
rrh.erase(unique(rrh.begin(), rrh.end()), rrh.end());
for (int i = 1; i <= n; ++i)
c[i] = upper_bound(rrh.begin(), rrh.end(), c[i]) - rrh.begin();
}
dfs(1);
hld(par[1] = 1);
for (int i = 1; i < n; ++i) {
const auto& [u, v] = edge[i];
cout << cal(v, c[v]) << "\n";
}
}
/*
 * Judge status-page residue pasted after the program (not part of the code):
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */