#include <bits/stdc++.h>
using namespace std;

// Shorthand types and output separators used by the solution.
using ll = long long;
using ld = long double;
constexpr char sp = ' ';
constexpr char en = '\n';

// In-place "keep the smaller / keep the larger" helpers.
template <class T> void smin(T &a, const T &b) { a = min(a, b); }
template <class T> void smax(T &a, const T &b) { a = max(a, b); }

constexpr int N = 100002;        // capacity of every per-node / per-road array (1e5 + 2)
constexpr int inf = 1000000000;  // 1e9

int a[N];    // a[i]: endpoint of road i that receives the new child
int b[N];    // b[i]: node attached by road i (stored as a child of a[i])
int c[N];    // c[v]: value read for node v
int inv[N];  // inv[v]: index i of the road with b[i] == v (stays 0 for node 1)
int n;       // number of nodes

vector<int> g[N];  // g[v]: children of v, in input order

int par[N][17];  // par[v][k]: 2^k-th ancestor of v (0 past the top)
int in[N];       // in[v]: Euler-tour entry time
int out[N];      // out[v]: largest entry time inside v's subtree
int dep[N];      // dep[v]: depth below node 1
int tsz;         // Euler-tour clock

int st[4 * N];  // max segment tree over Euler positions
// Point assignment in the max segment tree st[]: `node` covers positions [l, r].
// Sets the leaf for position p to o, then refreshes the maxima on the way back up.
void Add(int node, int l, int r, int p, int o) {
    if (l != r) {
        const int mid = (l + r) >> 1;
        const int lc = 2 * node;
        const int rc = lc + 1;
        if (p > mid)
            Add(rc, mid + 1, r, p, o);
        else
            Add(lc, l, mid, p, o);
        st[node] = max(st[lc], st[rc]);
        return;
    }
    st[node] = o;  // reached the leaf for position p
}
// Maximum of st[] over positions [ql, qr], searching the node that covers [l, r].
// Returns 0 when the ranges do not overlap (main also stores 0 for removed positions).
int Get(int node, int l, int r, int ql, int qr) {
    const bool disjoint = qr < l || r < ql;
    if (disjoint) return 0;
    const bool covered = ql <= l && r <= qr;
    if (covered) return st[node];
    const int mid = (l + r) >> 1;
    const int fromLeft = Get(2 * node, l, mid, ql, qr);
    const int fromRight = Get(2 * node + 1, mid + 1, r, ql, qr);
    return fromLeft < fromRight ? fromRight : fromLeft;
}
// Euler-tour DFS of the subtree rooted at s (whose parent is recorded as e), visiting the
// children of each node in g[] order. For every visited node v it records:
//   in[v]        - entry time from the global clock tsz (pre-incremented, so 1-based),
//   out[v]       - the clock value after v's whole subtree has been entered,
//   par[v][0..16] - binary-lifting ancestors (par[v][0] is the parent),
// and writes inv[v] into the segment tree at position in[v].
// dep[] is set for every proper descendant of s; dep[s] itself is left untouched.
//
// Uses an explicit stack instead of recursion: on a path-shaped tree the recursive form
// nests about n calls deep, which can overflow the call stack.
void Dfs(int s, int e) {
    // Each frame: a node and the position in g[node] of the next child to visit.
    std::vector<std::pair<int, int>> stk;
    auto enter = [&stk](int v, int p) {
        in[v] = ++tsz;
        par[v][0] = p;
        // Ancestors are entered first, so their tables are already complete.
        for (int i = 1; i < 17; i++) par[v][i] = par[par[v][i - 1]][i - 1];
        Add(1, 1, n, in[v], inv[v]);
        stk.emplace_back(v, 0);
    };
    enter(s, e);
    while (!stk.empty()) {
        const int v = stk.back().first;
        const int k = stk.back().second;
        if (k < static_cast<int>(g[v].size())) {
            ++stk.back().second;  // advance before entering: enter() may reallocate stk
            const int u = g[v][k];
            dep[u] = dep[v] + 1;
            enter(u, v);
        } else {
            out[v] = tsz;  // every node of v's subtree has been entered
            stk.pop_back();
        }
    }
}
// True when u is v or an ancestor of v, i.e. v's Euler interval nests inside u's.
bool Ancestor(int u, int v) {
    if (in[v] < in[u]) return false;  // v entered before u: cannot be below it
    return out[u] >= out[v];
}
// Lowest common ancestor of u and v, using the par[][] lifting table and the
// Euler-interval ancestor test.
int Lca(int u, int v) {
    if (Ancestor(u, v)) return u;
    if (Ancestor(v, u)) return v;
    // Neither contains the other: climb from u as high as possible while staying off
    // v's ancestor chain. The parent of where we stop is the LCA.
    int cur = u;
    for (int k = 17; k-- > 0;) {
        const int jump = par[cur][k];
        if (jump == 0) continue;          // would leave the tree
        if (Ancestor(jump, v)) continue;  // would overshoot onto v's chain
        cur = jump;
    }
    return par[cur][0];
}
// Inversion count accumulated by Resi for the query currently being answered.
ll sol;

// Fenwick (binary indexed) tree over indices 1..N-1.
// Add(x, y) adds y at index x; Get(x) returns the sum over indices 1..x.
struct Fenwick {
    ll bit[N];
    // x must lie in [1, N-1]. With x <= 0 the update loop never advances
    // (0 & -0 == 0), and with x >= N the update would be silently dropped.
    void Add(int x, int y) {
        assert(1 <= x && x < N);
        for (; x < N; x += x & (-x)) bit[x] += y;
    }
    // Prefix sum over 1..x. An x past the last index is clamped, since that prefix is the
    // whole table; without the clamp the loop would read beyond bit[].
    ll Get(int x) {
        if (x >= N) x = N - 1;
        ll ans = 0;
        for (; x > 0; x -= x & (-x)) ans += bit[x];
        return ans;
    }
} f;
// Adds to the global `sol` the number of node pairs (p, q) on the vertical path s -> x,
// with p strictly above q, where value(p) > value(q). x must lie in s's subtree.
//
// The path is consumed top to bottom as runs of equal value. For a run starting at s:
//   o = maximum stored in the segment tree over s's Euler interval, y = b[o];
//   every node from s down to lca(x, y) takes value c[y], so the run has
//   dep[lca] - dep[s] + 1 nodes.
// The Fenwick tree f holds (value -> node count) for the runs already passed. Each node
// of the current run pairs with every larger-valued node above it, so a run contributes
// len * #greater. f is restored before returning. c[y] must be in [1, N-1] to index f.
//
// Written as a loop: a single path can split into many runs, so recursion could nest deeply.
void Resi(int s, int x) {
    std::vector<std::pair<int, int>> added;  // (value, count) pushed into f, undone at the end
    for (;;) {
        const int o = Get(1, 1, n, in[s], out[s]);
        const int y = b[o];
        const int lca = Lca(x, y);
        const int len = dep[lca] - dep[s] + 1;  // nodes s..lca share value c[y]
        // Weight by len: every node of this run forms an inversion with each larger value above.
        sol += (f.Get(N - 1) - f.Get(c[y])) * len;
        if (lca == x) break;  // this run reaches x: the path is exhausted
        f.Add(c[y], len);
        added.emplace_back(c[y], len);
        // The next run starts at lca's child toward x. Dfs visits children in g[] order, so
        // their in[] times increase; take the last child with in <= in[x].
        int lo = 0;
        int hi = static_cast<int>(g[lca].size()) - 1;
        int gde = -1;
        while (lo <= hi) {
            const int mid = (lo + hi) >> 1;
            if (in[g[lca][mid]] <= in[x]) {
                gde = mid;
                lo = mid + 1;
            } else {
                hi = mid - 1;
            }
        }
        assert(gde != -1);
        s = g[lca][gde];
    }
    for (const auto &p : added) f.Add(p.first, -p.second);
}
// Reads n, the node values c[1..n] and the n-1 roads (a[i], b[i]), where road i attaches
// b[i] as a child of a[i]. Roads are answered offline from last to first. For road i, only
// nodes created by roads 1..i-1 are kept in the segment tree, and Resi counts the
// inversions on the path from node 1 to a[i]. The answers are printed in road order.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cin >> n;
    for (int i = 1; i <= n; i++) std::cin >> c[i];

    // The Fenwick tree is indexed by value and has only N slots, so a raw value >= N
    // indexes out of bounds (the judge's signal 11). Counting inversions needs only the
    // relative order of values, so replace each value by its rank 1..k (k <= n < N).
    // NOTE(review): presumably the statement allows values far above N (e.g. 1e9) — confirm.
    {
        std::vector<int> vals(c + 1, c + n + 1);
        std::sort(vals.begin(), vals.end());
        vals.erase(std::unique(vals.begin(), vals.end()), vals.end());
        for (int i = 1; i <= n; i++)
            c[i] = static_cast<int>(std::lower_bound(vals.begin(), vals.end(), c[i]) - vals.begin()) + 1;
    }

    for (int i = 1; i < n; i++) {
        std::cin >> a[i] >> b[i];
        inv[b[i]] = i;  // node b[i] is created by road i; node 1 keeps 0
        g[a[i]].push_back(b[i]);
    }
    Dfs(1, 0);

    // A path of L nodes can hold up to L*(L-1)/2 inversions, which exceeds INT_MAX
    // for L around 1e5, so answers are kept as long long (sol is already ll).
    std::vector<long long> ans;
    for (int i = n - 1; i >= 1; i--) {
        // Drop node b[i] so only nodes created by roads 1..i-1 remain.
        Add(1, 1, n, in[b[i]], 0);
        sol = 0;
        if (a[i] == 1) {
            ans.push_back(0);  // the path is just node 1
            continue;
        }
        Resi(1, a[i]);
        ans.push_back(sol);
    }
    // ans was filled from the last road to the first.
    for (int i = static_cast<int>(ans.size()) - 1; i >= 0; i--) std::cout << ans[i] << '\n';
    return 0;
}
/*
Compilation message reported by the judge:
construction.cpp: In function 'void Add(int, int, int, int, int)':
construction.cpp:20:17: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   20 | int mid = l + r >> 1;
      |             ~~^~~
construction.cpp: In function 'int Get(int, int, int, int, int)':
construction.cpp:28:17: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   28 | int mid = l + r >> 1;
      |             ~~^~~
construction.cpp: In function 'void Resi(int, int)':
construction.cpp:74:21: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   74 | int mid = l + r >> 1;
      |             ~~^~~
*/
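/*
Judge results:
| # | Result        | Execution time | Memory  | Grader output                   |
|---|---------------|----------------|---------|---------------------------------|
| 1 | Correct       | 1 ms           | 2644 KB | Output is correct               |
| 2 | Runtime error | 4 ms           | 5204 KB | Execution killed with signal 11 |
| 3 | Halted        | 0 ms           | 0 KB    | -                               |
*/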
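/*
Judge results:
| # | Result        | Execution time | Memory  | Grader output                   |
|---|---------------|----------------|---------|---------------------------------|
| 1 | Correct       | 1 ms           | 2644 KB | Output is correct               |
| 2 | Runtime error | 4 ms           | 5204 KB | Execution killed with signal 11 |
| 3 | Halted        | 0 ms           | 0 KB    | -                               |
*/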
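/*
Judge results:
| # | Result        | Execution time | Memory  | Grader output                   |
|---|---------------|----------------|---------|---------------------------------|
| 1 | Correct       | 1 ms           | 2644 KB | Output is correct               |
| 2 | Runtime error | 4 ms           | 5204 KB | Execution killed with signal 11 |
| 3 | Halted        | 0 ms           | 0 KB    | -                               |
*/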