#include <bits/stdc++.h>
using namespace std;
using ll = long long;
constexpr int MAXN = 100000 + 25;  // array capacity; presumably n <= 1e5 plus slack — confirm against constraints
constexpr int B = 18;              // binary-lifting levels (2^17 = 131072)

int n;                     // number of vertices
int C[MAXN];               // C[i]: value attached to vertex i
array <int, 2> ee[MAXN];   // ee[i] = {x, y}: i-th input edge, x is the parent of y
int p[MAXN];               // parent of each vertex (0 for the root, vertex 1)
int in[MAXN];              // pre-order position assigned by dfs2
int tt;                    // running pre-order counter used by dfs2
int dep[MAXN];             // depth (root has depth 0)
int v[MAXN];               // NOTE(review): not referenced anywhere in this file
int dp[B][MAXN];           // dp[j][i]: 2^j-th ancestor of i (0 once past the root)
int sze[MAXN];             // subtree sizes computed by dfs1
int nxt[MAXN];             // nxt[i]: topmost vertex of the heavy chain containing i
vector <int> adj[MAXN];    // children lists; dfs1 moves the heaviest child to index 0
// Computes sze[] for the subtree rooted at pos and reorders adj[pos] so that
// adj[pos][0] is a child with the largest subtree (a later child only takes
// the front slot if it is strictly larger). Recursive: depth = tree height.
void dfs1 (int pos) {
    sze[pos] = 1;
    vector <int> &kids = adj[pos];
    for (size_t k = 0; k < kids.size(); ++k) {
        dfs1(kids[k]);
        sze[pos] += sze[kids[k]];
        // keep the heaviest child seen so far at the front of the list
        if (sze[kids[k]] > sze[kids.front()]) {
            swap(kids[k], kids.front());
        }
    }
}
// Assigns pre-order positions in[] and chain heads nxt[] below pos.
// adj[pos][0] is visited first; after dfs1 that is the heavy child, so each
// heavy chain occupies a contiguous range of in[] values. The caller must set
// nxt[] for the starting vertex.
void dfs2 (int pos) {
    in[pos] = ++tt;
    const vector <int> &kids = adj[pos];
    for (size_t k = 0; k < kids.size(); ++k) {
        const int child = kids[k];
        if (k == 0) {
            nxt[child] = nxt[pos];   // first child continues pos's chain
        } else {
            nxt[child] = child;      // every other child starts its own chain
        }
        dfs2(child);
    }
}
#define mid ((l + r) >> 1)
#define tl (node << 1)
#define tr (node << 1 | 1)
struct SegmentTree {
int tree[MAXN << 2], lazy[MAXN << 2];
void prop (int l, int r, int node) {
if (lazy[node] == 0) {
return;
}
if (l != r) {
lazy[tl] = lazy[node];
lazy[tr] = lazy[node];
}
tree[node] = lazy[node];
lazy[node] = 0;
}
void update (int l, int r, int a, int b, int c, int node) {
prop(l, r, node);
if (l > b || r < a) return;
if (l >= a && r <= b) {
lazy[node] = c;
prop(l, r, node);
return;
}
update(l, mid, a, b, c, tl);
update(mid + 1, r, a, b, c, tr);
}
int get (int l, int r, int a, int node) {
prop(l, r, node);
if (l == r) {
return tree[node];
}
if (a <= mid) {
return get(l, mid, a, tl);
} else {
return get(mid + 1, r, a, tr);
}
}
} cur;
// highest[c]: topmost vertex currently carrying color c in the segment tree
// `cur` (color c stands for value C[c]). solve() treats the vertices of color c
// as the vertical path highest[c] .. c: repainting from c itself sets the entry
// to -1 (color gone), otherwise it moves down to the child just below the
// repainted part. Set to 1 whenever a new color is painted from the root.
int highest[MAXN];
// Fenwick tree (binary indexed tree) of int sums over indices 1 .. MAXN-1.
struct BIT {
    int tree[MAXN];

    // Adds y at index x. Precondition: x >= 1 — with x == 0 the step
    // x & -x is 0, so the loop would never advance.
    void add (int x, int y) {
        while (x < MAXN) {
            tree[x] += y;
            x += x & (-x);
        }
    }

    // Returns the sum of entries at indices 1..x (0 when x <= 0).
    int get (int x) {
        int total = 0;
        while (x > 0) {
            total += tree[x];
            x &= x - 1;   // drop the lowest set bit
        }
        return total;
    }
} cur2;
// Sorted table of all values C[1..n] plus a sentinel 0, filled by solve().
// Only searched with lower_bound, so duplicate values are harmless.
vector <int> cc;

// Counts inversions in a run-length encoded sequence.
//   runs: (value, count) pairs in sequence order (solve() passes them
//         root-first); every value must be present in cc.
// Returns the number of index pairs i < j with value_i > value_j, where a pair
// of runs contributes count_i * count_j. Uses the global Fenwick tree cur2,
// which must be all zero on entry; it is restored to all zero before returning.
ll inversions (const vector <pair <int, int>> &runs) {
    // Fenwick index of a value: its first position in cc plus one. The +1
    // keeps every index >= 1 even for the smallest value, because
    // cur2.add(0, ...) would never terminate.
    auto rank_of = [](int value) {
        return int(lower_bound(cc.begin(), cc.end(), value) - cc.begin()) + 1;
    };
    vector <int> ranks;
    ranks.reserve(runs.size());
    ll ret = 0;
    int seen = 0;  // total count of all elements inserted so far
    for (const auto &[value, count] : runs) {
        const int r = rank_of(value);
        // earlier elements strictly greater than value = all earlier ones
        // minus those with value <= this one
        const int greater = seen - cur2.get(r);
        ret += (ll)greater * (ll)count;
        seen += count;
        cur2.add(r, count);
        ranks.push_back(r);
    }
    // Undo every insertion so the next call starts from an empty tree.
    for (size_t k = 0; k < runs.size(); k++) {
        cur2.add(ranks[k], -runs[k].second);
    }
    return ret;
}
// Reads n and the values C[1..n], and builds the sorted lookup table cc
// (with a sentinel 0) used by inversions().
static void read_values () {
    cin >> n;
    for (int i = 1; i <= n; i++) {
        cin >> C[i];
    }
    cc.reserve(cc.size() + n + 1);
    cc.push_back(0);
    for (int i = 1; i <= n; i++) {
        cc.push_back(C[i]);
    }
    sort(cc.begin(), cc.end());
}

// Reads the n-1 edges (x, y), where x is already in the tree and y is attached
// below it, and records parent, depth and children lists.
static void read_edges () {
    for (int i = 1; i < n; i++) {
        int x, y;
        cin >> x >> y;
        ee[i] = {x, y};
        p[y] = x;
        dep[y] = 1 + dep[x];
        adj[x].push_back(y);
    }
}

// Builds the heavy-light decomposition (sze, in, nxt) and the binary-lifting
// table dp[j][i] = 2^j-th ancestor of i (0 once past the root).
static void build_structure () {
    dfs1(1);
    nxt[1] = 1;
    dfs2(1);
    for (int i = 1; i <= n; i++) {
        dp[0][i] = p[i];
    }
    for (int j = 1; j < B; j++) {
        for (int i = 1; i <= n; i++) {
            dp[j][i] = dp[j - 1][dp[j - 1][i]];
        }
    }
}

// Returns the ancestor of `from` (possibly `from` itself) whose parent is
// `above`; assumes `above` is a proper ancestor of `from`.
static int child_toward (int from, int above) {
    int g = from;
    for (int k = B - 1; k >= 0; k--) {
        if (dp[k][g] != 0 && dep[dp[k][g]] > dep[above]) {
            g = dp[k][g];
        }
    }
    return g;
}

// Assigns `color` to every vertex on the vertical path from x up to its
// ancestor `top` (inclusive), one heavy chain at a time.
static void paint_path (int x, int top, int color) {
    while (nxt[x] != nxt[top]) {
        cur.update(1, n, in[nxt[x]], in[x], color, 1);
        x = p[nxt[x]];
    }
    cur.update(1, n, in[top], in[x], color, 1);
}

// Walks from x to the root one color segment at a time, repainting each
// segment with color y. Returns the (value, count) runs seen along the way,
// ordered root-first.
static vector <pair <int, int>> collect_and_repaint (int x, int y) {
    vector <pair <int, int>> runs;
    while (true) {
        const int c = cur.get(1, n, in[x], 1);
        const int z = highest[c];   // top of the segment that contains x
        runs.push_back({C[c], dep[x] - dep[z] + 1});
        // Whatever stays of color c is the part strictly below x.
        highest[c] = (x == c ? -1 : child_toward(c, x));
        paint_path(x, z, y);
        if (z == 1) {
            break;
        }
        x = p[z];
    }
    reverse(runs.begin(), runs.end());
    return runs;
}

// Solves one test case. For each edge (x, y) in input order it:
//   1. prints the inversion count of the values along the root..x path;
//   2. repaints root..x and y with y's color.
void solve () {
    read_values();
    read_edges();
    build_structure();
    cur.update(1, n, in[1], in[1], 1, 1);
    highest[1] = 1;
    for (int i = 1; i < n; i++) {
        const int x = ee[i][0], y = ee[i][1];
        cout << inversions(collect_and_repaint(x, y)) << '\n';
        cur.update(1, n, in[y], in[y], y, 1);
        highest[y] = 1;
    }
}
// Entry point: untie and unsync the C++ streams for fast I/O, then run a
// single test case (multi-test input stays disabled, as before).
int main () {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int tc = 1; // cin >> tc;
    for (; tc > 0; --tc) {
        solve();
    }
}
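/*
Judge status tables captured with this submission (results had not loaded yet):

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |

| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |
*/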