#include <bits/stdc++.h>
using namespace std;
#define int long long
// Sentinel meaning "never assigned" in the segment tree.
const int inf = 1e18;
// Capacity for node ids, edge ids and compressed values.
const int N = 1e5 + 10;
// Number of binary-lifting levels (2^17 = 131072 > N).
const int LG = 17;

int n;              // number of nodes read from input
int lift[N][LG];    // lift[v][k]: ancestor of v that is 2^k levels up
int dep[N];         // depth of v (node 1 stays at 0)
int root[N];        // top node of v's heavy chain
int sz[N];          // subtree size of v
int tin[N];         // entry time of v in the chain-ordered traversal
int par[N];         // par[v] = u for each input edge (u, v)
int c[N];           // compressed value of node v
int fen[N];         // Fenwick tree indexed by compressed value
int timer = 0;      // last entry time handed out
int paint[N];       // paint[i]: compressed value of the v-endpoint of edge i
vector<int> adj[N];     // adj[u] collects v for every input edge (u, v)
pair<int, int> e[N];    // e[i] = (u, v), the i-th input edge
// Fenwick tree point update: adds k at 1-based position idx of fen[].
void upd(int idx, int k) {
    while (idx < N) {
        fen[idx] += k;
        idx += idx & -idx;  // move to the next node whose range covers idx
    }
}
// Fenwick tree prefix sum: returns fen-accumulated total over positions 1..idx.
int query(int idx) {
    int sum = 0;
    while (idx != 0) {
        sum += fen[idx];
        idx -= idx & -idx;  // drop the lowest set bit
    }
    return sum;
}
// First pass over the subtree of u:
//  - fills lift[u][1..LG-1] from lift[u][0] (which the caller must set),
//  - sets lift[child][0] and dep[child] for each child before recursing,
//  - computes sz[u],
//  - moves the child with the largest subtree into adj[u][0] so that hld()
//    visits it first.
void dfs(int u, int p) {
    for (int k = 1; k < LG; ++k) {
        const int half = lift[u][k - 1];
        lift[u][k] = lift[half][k - 1];
    }
    sz[u] = 1;
    vector<int>& kids = adj[u];
    for (size_t idx = 0; idx < kids.size(); ++idx) {
        int& child = kids[idx];
        if (child == p) continue;
        lift[child][0] = u;
        dep[child] = dep[u] + 1;
        dfs(child, u);
        sz[u] += sz[child];
        // Keep the biggest processed child at the front of the list.
        if (sz[child] > sz[kids[0]]) std::swap(child, kids[0]);
    }
}
// Second pass: assigns preorder entry times tin[] and chain tops root[].
// adj[u][0] continues u's chain (root[v] = root[u]). Every other child starts a
// new chain at itself. Because adj[u][0] is visited first (unless it is p), it
// receives tin[u] + 1, so the nodes of one chain get consecutive tin values.
void hld(int u, int p) {
    timer += 1;
    tin[u] = timer;
    for (const auto& v : adj[u]) {
        if (v != p) {
            if (v == adj[u][0]) {
                root[v] = root[u];
            } else {
                root[v] = v;
            }
            hld(v, u);
        }
    }
}
// Lazy segment tree supporting two operations on positions [l, r] (callers pass
// the root range explicitly, here [1, n]):
//  - assign val to every position in [ql, qr];
//  - read the value at a single position.
// A position that was never assigned reads as -inf. Only leaf values are ever
// read; seg[] entries of internal nodes are not kept meaningful.
struct SegTree {
    int size = 1;
    vector<int> seg, lazy;

    // Prepare storage for a root range of up to n positions. Node ids are
    // 1-based, and the children of id are 2*id and 2*id+1. For a root range of
    // at most `size` positions every visited id stays below 2*size. 4*size
    // also covers root ranges of up to 2*size positions.
    void init(int n) {
        while (size < n) size *= 2;
        seg.assign(4 * size + 10, -inf);
        lazy.assign(4 * size + 10, -inf);
    }

    // Materialize node id's pending assignment into seg[id]. Internal nodes
    // forward it to both children. Leaves (is_leaf == true) have no children.
    // Forwarding from a leaf would write to ids near 4*size, which the old
    // 2*size+10 allocation did not cover (heap corruption -> abort).
    void push(int id, bool is_leaf = false) {
        if (lazy[id] == -inf) return;
        seg[id] = lazy[id];
        if (!is_leaf) {
            lazy[id * 2] = lazy[id];
            lazy[id * 2 + 1] = lazy[id];
        }
        lazy[id] = -inf;
    }

    // Assign val to every position in [ql, qr]. Node id covers [l, r].
    void update(int ql, int qr, int val, int l, int r, int id) {
        push(id, l == r);
        if (qr < l || r < ql) return;
        if (ql <= l && r <= qr) {
            lazy[id] = val;
            push(id, l == r);
            return;
        }
        int mid = (l + r) / 2;
        update(ql, qr, val, l, mid, id * 2);
        update(ql, qr, val, mid + 1, r, id * 2 + 1);
    }

    // Value at position pos. Node id covers [l, r].
    int query(int pos, int l, int r, int id) {
        push(id, l == r);
        if (l == r) return seg[id];
        int mid = (l + r) / 2;
        if (pos <= mid) return query(pos, l, mid, id * 2);
        return query(pos, mid + 1, r, id * 2 + 1);
    }
} ST;
// Assigns val to every node on the tree path x..y, one heavy-chain segment at a
// time. Each step paints the segment from chain top root[y] down to y, which is
// a contiguous tin range. The endpoint whose chain top is deeper always moves.
void update_path(int x, int y, int val) {
    while (root[x] != root[y]) {
        if (dep[root[y]] < dep[root[x]]) std::swap(x, y);
        const int top = root[y];
        ST.update(tin[top], tin[y], val, 1, n, 1);
        y = par[top];
    }
    // Both endpoints now share a chain: paint from the shallower one downward.
    if (dep[y] < dep[x]) std::swap(x, y);
    ST.update(tin[x], tin[y], val, 1, n, 1);
}
// Returns the node reached by moving `dist` levels up from `sta` via the lift[]
// table, highest bit first. Only bits 0..LG-1 of dist are used. If dist exceeds
// the depth, the result presumably stays at node 1, since main maps lift[1][0]
// to 1 — confirm.
int jump(int sta, int dist) {
    int cur = sta;
    for (int bit = LG; bit-- > 0;) {
        if ((dist >> bit) & 1) cur = lift[cur][bit];
    }
    return cur;
}
int32_t main() {
    ios::sync_with_stdio(0); cin.tie(0);
    cin >> n;

    // Read node values and coordinate-compress them to 1..(#distinct) so they
    // can index the Fenwick tree directly.
    vector<int> disc;
    for (int i = 1; i <= n; i++) {
        cin >> c[i];
        disc.push_back(c[i]);
    }
    sort(disc.begin(), disc.end());
    disc.erase(unique(disc.begin(), disc.end()), disc.end());
    for (int i = 1; i <= n; i++) {
        c[i] = lower_bound(disc.begin(), disc.end(), c[i]) - disc.begin() + 1;
    }

    // Edge i = (u, v) is stored as a directed u -> v edge. paint[i] is the value
    // that edge i's path update writes: the compressed value of v.
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        paint[i] = c[v];
        adj[u].push_back(v);
        par[v] = u;
        e[i] = {u, v};
    }

    lift[1][0] = 1, root[1] = 1;
    dfs(1, -1);
    hld(1, -1);
    ST.init(n + 1);

    // ST stores, per node, the index ("stamp") of the last edge whose path
    // update covered it. That node's current value is paint[stamp]. Each
    // update_path(1, v, i) stamps a whole path up to node 1 with increasing i,
    // so stamps never decrease when moving toward node 1. Equal stamps
    // therefore form contiguous runs.
    vector<pair<int, int>> vec;  // (stamp, run length), reused every iteration
    for (int i = 1; i < n; i++) {
        int node = e[i].first;
        vec.clear();
        int tmp = ST.query(tin[node], 1, n, 1);
        // The previous update stamped i-1 all the way up to node 1, so reaching
        // stamp i-1 means the rest of the path is a single run. Before the
        // first update nothing is stamped and every query returns -inf.
        int rt = i - 1;
        if (!rt) rt = -inf;
        while (node >= 1) {
            if (tmp == rt) {
                vec.push_back({tmp, dep[node] + 1});
                break;
            }
            // Find the smallest l in [1, dep[node]] such that the l-th ancestor
            // of node has a stamp different from tmp.
            int l = 1, r = dep[node];
            while (l < r) {
                int mid = (l + r) / 2;
                int tar = jump(node, mid);
                if (ST.query(tin[tar], 1, n, 1) != tmp) r = mid;
                else l = mid + 1;
            }
            vec.push_back({tmp, l});
            node = jump(node, l);
            tmp = ST.query(tin[node], 1, n, 1);
        }
        // Runs were collected bottom-up; process them from node 1 downward.
        reverse(vec.begin(), vec.end());

        // Count pairs (a, b) on the path where a is closer to node 1 and
        // value(a) > value(b). The Fenwick tree holds the value counts of runs
        // already processed, all of which are closer to node 1.
        // For i == 1 nothing is stamped yet and 0 is reported. This assumes the
        // first edge starts at node 1 — TODO confirm against the problem.
        int inv = 0, tot = 0;
        if (i > 1) {
            for (auto& [x, y] : vec) {
                // Each of the y nodes in this run pairs with every one of the
                // (tot - query(paint[x])) earlier nodes of strictly greater value.
                inv += y * (tot - query(paint[x]));
                upd(paint[x], y);
                tot += y;
            }
            for (auto& [x, y] : vec) upd(paint[x], -y);  // reset the Fenwick tree
        }
        update_path(1, e[i].second, i);
        cout << inv << '\n';
    }
}
/*
Judge report for this submission, pasted after the source code; the same
report appears three times. Columns: # | Result (결과) | Execution time
(실행 시간) | Memory (메모리) | Grader output. Signal 6 is SIGABRT (the
process aborted).

Copy 1:
# | Result        | Time  | Memory   | Grader output
1 | Correct       | 2 ms  | 10576 KB | Output is correct
2 | Correct       | 2 ms  | 10576 KB | Output is correct
3 | Runtime error | 22 ms | 21328 KB | Execution killed with signal 6
4 | Halted        | 0 ms  | 0 KB     | -

Copy 2:
# | Result        | Time  | Memory   | Grader output
1 | Correct       | 2 ms  | 10576 KB | Output is correct
2 | Correct       | 2 ms  | 10576 KB | Output is correct
3 | Runtime error | 22 ms | 21328 KB | Execution killed with signal 6
4 | Halted        | 0 ms  | 0 KB     | -

Copy 3:
# | Result        | Time  | Memory   | Grader output
1 | Correct       | 2 ms  | 10576 KB | Output is correct
2 | Correct       | 2 ms  | 10576 KB | Output is correct
3 | Runtime error | 22 ms | 21328 KB | Execution killed with signal 6
4 | Halted        | 0 ms  | 0 KB     | -
*/