#include <bits/stdc++.h>
using namespace std;
// 64-bit integer alias for answers that may exceed the int range.
using ll = long long;
// Capacity of the per-vertex arrays; assumes n < N.
const int N = 100007;
int n;              // number of vertices read from input
int c[N];           // value of each vertex (main() replaces it by its compressed rank)
int a[N];           // a[i]: parent endpoint of edge i (0-based)
int b[N];           // b[i]: child endpoint of edge i (0-based)
vector<int> g[N];   // children of each vertex
int par[N];         // parent of each vertex, -1 for the root
int sz[N];          // subtree size of each vertex
int top[N];         // first vertex of the heavy path containing each vertex
int tin[N];         // position of each vertex in `order`
int tout[N];        // one past the last position used by each vertex's subtree
vector<int> order;  // vertices in DFS pre-order
int t = 0;          // next free position while dfs() runs
// Depth-first traversal from u that lays the tree out for heavy-light
// decomposition:
//  * tin[u] is u's index in `order` (pre-order position) and tout[u] is one
//    past the last index used by u's subtree.
//  * The child with the largest sz is moved to g[u][0] and visited first, so
//    each heavy path occupies consecutive positions.
//  * top[v] is the first vertex of the heavy path containing v.
// Requires sz[] to already hold subtree sizes. top[u] must be set by the
// caller; for the root (vertex 0) this relies on the zero-initialised global
// top[0] == 0.
void dfs(int u) {
    tin[u] = t++;
    order.push_back(u);
    // Bring the heaviest child to the front. The loop variable must be a
    // reference: swapping with a copy would overwrite g[u][0] with a duplicate
    // of v and drop the previous front child from the adjacency list.
    for (auto &v : g[u]) {
        if (sz[v] > sz[g[u][0]]) {
            swap(v, g[u][0]);
        }
    }
    for (auto v : g[u]) {
        // The heavy child continues its parent's path; light children start a new one.
        top[v] = (v == g[u][0] ? top[u] : v);
        dfs(v);
    }
    tout[u] = t;
}
// Summary stored in every segment-tree node: the minimum and maximum value
// over the node's range. A single value v is represented as {v, v}.
struct Node {
    int mn;
    int mx;
    Node(int val = 0) : mn(val), mx(val) {}
};
// Combines the summaries of two adjacent ranges into the summary of their union.
Node operator + (Node a, Node b) {
    Node merged(min(a.mn, b.mn));
    merged.mx = max(a.mx, b.mx);
    return merged;
}
// Segment tree in heap layout (children of v are 2v and 2v+1, root is 1).
Node tree[N * 4];
// Pending range-assignment per node; -1 means "nothing pending".
int color[N * 4];
// Initialises segment-tree node v, which covers positions [tl, tr) of
// `order`. A leaf stores the value c of the vertex at that position; an inner
// node stores the min/max of its two children. Every visited node starts with
// no pending assignment (color == -1).
void build(int v, int tl, int tr) {
    color[v] = -1;
    if (tr - tl != 1) {
        const int mid = (tl + tr) / 2;
        const int lc = 2 * v;
        const int rc = lc + 1;
        build(lc, tl, mid);
        build(rc, mid, tr);
        tree[v] = tree[lc] + tree[rc];
        return;
    }
    tree[v] = Node(c[order[tl]]);
}
// If node v has a pending assignment, applies it to both children (their
// summaries and their own pending tags) and then clears it on v.
void push(int v) {
    const int pending = color[v];
    if (pending != -1) {
        for (int child : {2 * v, 2 * v + 1}) {
            tree[child] = Node(pending);
            color[child] = pending;
        }
        color[v] = -1;
    }
}
// Assigns value c (this parameter shadows the global array c) to every
// position in [l, r) that lies inside node v's range [tl, tr).
// A fully covered node is overwritten and tagged. A partially covered node
// pushes its tag down, recurses into both children, and recomputes its summary.
void update(int v, int tl, int tr, int l, int r, int c) {
    const bool disjoint = tr <= l || r <= tl;
    if (disjoint) {
        return;
    }
    const bool covered = l <= tl && tr <= r;
    if (covered) {
        color[v] = c;
        tree[v] = Node(c);
        return;
    }
    push(v);
    const int mid = (tl + tr) / 2;
    update(2 * v, tl, mid, l, r, c);
    update(2 * v + 1, mid, tr, l, r, c);
    tree[v] = tree[2 * v] + tree[2 * v + 1];
}
// Returns the largest position in [x, y) within node v's range [tl, tr)
// whose current value differs from val, or -1 if there is none.
// The right child is searched first, so the first hit found is the rightmost one.
int find_prev(int v, int tl, int tr, int x, int y, int val) {
    // No overlap with the query range.
    if (tr <= x || y <= tl) {
        return -1;
    }
    // Every value under this node equals val.
    if (tree[v].mn == val && tree[v].mx == val) {
        return -1;
    }
    if (tr - tl == 1) {
        return tl;
    }
    // Children must reflect any pending assignment before they are inspected.
    push(v);
    const int mid = (tl + tr) / 2;
    const int hit = find_prev(2 * v + 1, mid, tr, x, y, val);
    return hit != -1 ? hit : find_prev(2 * v, tl, mid, x, y, val);
}
// Returns the current value at position pos, walking down from node v
// (covering [tl, tr)). The first node on the way that carries a pending
// assignment determines the value. If none does, the result is the
// build-time value c[order[pos]].
int get_color(int v, int tl, int tr, int pos) {
    while (color[v] == -1) {
        if (tr - tl == 1) {
            return c[order[tl]];
        }
        const int mid = (tl + tr) / 2;
        if (pos < mid) {
            v = 2 * v;
            tr = mid;
        } else {
            v = 2 * v + 1;
            tl = mid;
        }
    }
    return color[v];
}
// Fenwick (binary indexed) tree over indices 0..N-2, stored 1-based in fen.
int fen[N];
// Adds val to the entry at index i.
void add(int i, int val) {
    for (int k = i + 1; k < N; k += k & -k) {
        fen[k] += val;
    }
}
// Returns the sum of the entries at indices 0..i; returns 0 when i < 0.
int get(int i) {
    int total = 0;
    for (int k = i + 1; k > 0; k -= k & -k) {
        total += fen[k];
    }
    return total;
}
// Counts the pairs (x, y) on the root-to-u path where x is above y and
// value(x) > value(y), strictly.
// The path is split into maximal runs of equal value. Each run is found by
// reading the value at u and searching the segment tree for the nearest
// earlier position with a different value. A run never extends past the top
// of u's heavy path.
// Uses the Fenwick tree as scratch space and leaves it as it found it.
ll get_cost(int u) {
    // (value, length) runs, collected from u upward toward the root.
    vector<pair<int, int>> segs;
    while (u != -1) {
        int col = get_color(1, 0, n, tin[u]);
        // Rightmost position at or before tin[u] whose value differs from col.
        int pos = find_prev(1, 0, n, 0, tin[u] + 1, col);
        if (pos < tin[top[u]]) {
            // The whole heavy-path stretch top[u]..u has value col.
            segs.push_back({col, tin[u] - tin[top[u]] + 1});
            u = par[top[u]];
        } else {
            // Positions pos+1..tin[u] have value col; continue from order[pos].
            segs.push_back({col, tin[u] - pos});
            u = order[pos];
        }
    }
    ll ans = 0;
    // Runs are processed deepest-first. When a run is processed, the Fenwick
    // tree holds the counts of deeper vertices indexed by value.
    // Only strictly smaller values form inversions. Equal values inside one run
    // are never counted, so equal values in different runs must not be either;
    // run boundaries at heavy-path changes are an artifact of the decomposition.
    for (const auto &seg : segs) {
        ans += get(seg.first - 1) * 1ll * seg.second;
        add(seg.first, seg.second);
    }
    // Undo the insertions so the next query starts from an empty Fenwick tree.
    for (const auto &seg : segs) {
        add(seg.first, -seg.second);
    }
    return ans;
}
// Reads n vertex values and n-1 edges. For each edge i it first prints the
// inversion count on the root-to-a[i] path, then assigns b[i]'s value to
// every vertex on the root-to-b[i] path.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cin >> n;
    for (int i = 0; i < n; ++i) {
        cin >> c[i];
    }
    // Coordinate-compress values to ranks 0..n-1 so they can index the Fenwick tree.
    vector<int> sorted(c, c + n);
    sort(sorted.begin(), sorted.end());
    for (int i = 0; i < n; ++i) {
        c[i] = lower_bound(sorted.begin(), sorted.end(), c[i]) - sorted.begin();
    }
    // Edges are read 1-based and stored 0-based.
    for (int i = 0; i < n - 1; ++i) {
        cin >> a[i] >> b[i];
        --a[i];
        --b[i];
    }
    // a[i] is the parent of b[i]; vertex 0 is the root.
    par[0] = -1;
    for (int i = 0; i < n - 1; ++i) {
        par[b[i]] = a[i];
        g[a[i]].push_back(b[i]);
    }
    for (int i = 0; i < n; ++i) {
        sz[i] = 1;
    }
    // Subtree sizes, processing edges last-to-first. This assumes b[i] only
    // gains descendants through later edges, i.e. every a[j] is already
    // attached when edge j is read. Only edges 0..n-2 exist; starting at n-1
    // would read the unused a[n-1]/b[n-1] slots.
    for (int i = n - 2; i >= 0; --i) {
        sz[a[i]] += sz[b[i]];
    }
    dfs(0);
    build(1, 0, n);
    for (int i = 0; i < n - 1; ++i) {
        cout << get_cost(a[i]) << "\n";
        // Assign b[i]'s value along the root-to-b[i] path, one heavy-path
        // segment at a time.
        for (int u = b[i]; u != -1; u = par[top[u]]) {
            update(1, 0, n, tin[top[u]], tin[u] + 1, c[b[i]]);
        }
    }
    return 0;
}
/*
 * Pasted judge results:
 * #  | Result    | Execution time | Memory   | Grader output
 * 1  | Correct   | 2 ms           | 10584 KB | Output is correct
 * 2  | Incorrect | 2 ms           | 10588 KB | Output isn't correct
 * 3  | Halted    | 0 ms           | 0 KB     | -
 */
/*
 * Pasted judge results:
 * #  | Result    | Execution time | Memory   | Grader output
 * 1  | Correct   | 2 ms           | 10584 KB | Output is correct
 * 2  | Incorrect | 2 ms           | 10588 KB | Output isn't correct
 * 3  | Halted    | 0 ms           | 0 KB     | -
 */
/*
 * Pasted judge results:
 * #  | Result    | Execution time | Memory   | Grader output
 * 1  | Correct   | 2 ms           | 10584 KB | Output is correct
 * 2  | Incorrect | 2 ms           | 10588 KB | Output isn't correct
 * 3  | Halted    | 0 ms           | 0 KB     | -
 */