#include <bits/stdc++.h>
using namespace std;
// Capacity of every per-city / per-road array (valid indices 0 .. MAXN-1).
#define MAXN 100005
#define fi first
#define se second
// Lowest set bit of x; the step size of the Fenwick tree in updbit/gbit.
#define rb(x) ((x) & (-(x)))
typedef long long lint;
typedef pair<int, int> pii;
// C[v]: value of city v (main() replaces it by its rank 1..N after sorting).
// A[i], B[i]: road i; main() makes B[i] a child of A[i] (ed[A[i]] gets B[i]).
// B[0] is set to the root (city 1) in main().
int C[MAXN], A[MAXN], B[MAXN];
// ed[v]: children of city v in the rooted tree built by main().
vector<int> ed[MAXN];
// par[j][v]: 2^j-th ancestor of v (0 above the root); dep[v]: depth, root = 1.
// dfsord[v] / dfsr[v]: pre-order number of v and the largest pre-order number
// inside v's subtree, so v's subtree is the position range [dfsord[v], dfsr[v]].
// dfsn: running pre-order counter used by dfs().
int par[20][MAXN], dep[MAXN], dfsord[MAXN], dfsr[MAXN], dfsn;
// cs: sorted copy of the original values, used for rank compression.
// ch: per-key depth of the topmost node of a same-value run on the current
// root path; maintained by the path walk in main().
int cs[MAXN], ch[MAXN];
// seg: max segment tree over pre-order positions; a leaf holds the index of
// the road that activated that city (0 for the root, -1 if not yet active).
// bit: Fenwick tree over compressed values, holding node counts per value.
int seg[4 * MAXN], bit[MAXN];
//int xxxx = 0;
// Returns the ancestor of x that sits at depth d, using the binary-lifting
// table par[j][v] (2^j-th ancestor). For each bit j, the remaining depth gap
// is re-read from the current node and the 2^j jump is taken when bit j of
// that gap is set. Depth 0 maps to node 0, i.e. "above the root".
// NOTE(review): assumes 0 <= d <= dep[x] and depths below 2^20.
int anc(int x, int d) {
    int cur = x;
    for(int bitPos = 0; bitPos != 20; ++bitPos) {
        const int gap = dep[cur] - d;
        if(gap & (1 << bitPos)) {
            cur = par[bitPos][cur];
        }
    }
    return cur;
}
// Pre-order walk of the subtree rooted at x, in ed[] child order.
// Assigns dfsord[] on entry and dfsr[] (last pre-order number in the
// subtree) on exit, and fills par[0][] and dep[] for every child before
// descending into it. An explicit stack replaces recursion, so the visit
// order and all assignments match the recursive formulation.
void dfs(int x) {
    // Each frame is (node, index of the next child of that node to visit).
    std::vector<std::pair<int, std::size_t>> frames;
    dfsord[x] = ++dfsn;
    frames.emplace_back(x, 0);
    while(!frames.empty()) {
        const int node = frames.back().first;
        if(frames.back().second == ed[node].size()) {
            // All children finished: close this node's pre-order range.
            dfsr[node] = dfsn;
            frames.pop_back();
            continue;
        }
        const int child = ed[node][frames.back().second++];
        par[0][child] = node;
        dep[child] = dep[node] + 1;
        dfsord[child] = ++dfsn;
        frames.emplace_back(child, 0);
    }
}
// Point assignment on the max segment tree. Node idx covers [l, r] and has
// children 2*idx and 2*idx+1. Sets the leaf for position x to y, then
// recomputes the max of every node on the way back up.
void updseg(int idx, int l, int r, int x, int y) {
    if(l == r) {
        seg[idx] = y;
        return;
    }
    const int mid = (l + r) / 2;
    const int lc = idx * 2;
    const int rc = lc + 1;
    if(x > mid) {
        updseg(rc, mid + 1, r, x, y);
    } else {
        updseg(lc, l, mid, x, y);
    }
    seg[idx] = std::max(seg[lc], seg[rc]);
}
// Range-max query over positions [x, y] on the segment tree whose node idx
// covers [l, r]. Returns seg[idx] when the node lies fully inside the query,
// and -1 when the node is disjoint from it.
int gseg(int idx, int l, int r, int x, int y) {
    const bool covered = x <= l && r <= y;
    if(covered) return seg[idx];
    const bool disjoint = r < x || y < l;
    if(disjoint) return -1;
    const int mid = (l + r) / 2;
    const int fromLeft = gseg(2 * idx, l, mid, x, y);
    const int fromRight = gseg(2 * idx + 1, mid + 1, r, x, y);
    return std::max(fromLeft, fromRight);
}
// Fenwick-tree point update: adds y at 1-based position x, walking
// upward by the lowest set bit until the array capacity MAXN.
void updbit(int x, int y) {
    while(x < MAXN) {
        bit[x] += y;
        x += rb(x);
    }
}
// Fenwick-tree prefix sum: returns the total stored at positions 1..x
// (0 when x <= 0). Clearing the lowest set bit each step is the same as
// subtracting rb(x) for positive x.
int gbit(int x) {
    int total = 0;
    while(x > 0) {
        total += bit[x];
        x &= x - 1;
    }
    return total;
}
// Reads N, the values C[1..N], and N-1 roads (A[i], B[i]); B[i] becomes a
// child of A[i], and city 1 is the root.
// Cities are activated in order: the root at time 0, then B[i] at time i.
// The current value of a city v is C of the most recently activated city in
// v's subtree (found by a range-max over v's pre-order range).
// Hence every city on the path root..B[i] takes C[B[i]] once B[i] is added.
// For each road i, before activating B[i], prints the number of pairs (u, w)
// on the path root..A[i] where u is above w and value(u) > value(w).
//
// The path is split into maximal runs painted by the same activation
// ("writer"). ch[w] is the depth of the topmost node of writer w's run.
// Runs are keyed by writer index, not by value: distinct writers can share a
// value, and sharing one slot would let ch[] point below the current node.
// That made anc() return t itself and looped forever.
int main() {
    int N;
    if(scanf("%d", &N) != 1) return 0;
    for(int i = 1; i <= N; i++) if(scanf("%d", C + i) != 1) return 0;
    for(int i = 1; i < N; i++) if(scanf("%d%d", A + i, B + i) != 2) return 0;
    for(int i = 1; i < N; i++) ed[A[i]].push_back(B[i]);

    dep[1] = 1;
    dfs(1);
    // Binary-lifting table; par[*][0] stays 0, so jumps past the root end at 0.
    for(int i = 1; i < 20; i++)
        for(int j = 1; j <= N; j++) par[i][j] = par[i - 1][par[i - 1][j]];

    // Compress values to ranks 1..N (equal values get equal ranks) so they
    // can index the Fenwick tree.
    for(int i = 1; i <= N; i++) cs[i - 1] = C[i];
    sort(cs, cs + N);
    for(int i = 1; i <= N; i++) C[i] = lower_bound(cs, cs + N, C[i]) - cs + 1;

    // The segment tree spans pre-order positions [1, N]; indices below 4N
    // suffice. -1 marks "no activated city here".
    for(int i = 1; i < 4 * N; i++) seg[i] = -1;
    B[0] = 1;                       // writer 0 is the root itself
    updseg(1, 1, N, dfsord[1], 0);
    ch[0] = 1;                      // the root's run starts at depth 1

    for(int i = 1; i < N; i++) {
        // (value rank, node count) for each run, from A[i] up to the root.
        vector<pii> v;
        for(int t = A[i]; t != 0;) {
            // Latest writer in t's subtree determines t's current value.
            const int w = gseg(1, 1, N, dfsord[t], dfsr[t]);
            const int c = C[B[w]];
            v.push_back(make_pair(c, dep[t] - ch[w] + 1));
            // Jump to the node just above the top of writer w's run.
            const int k = anc(t, ch[w] - 1);
            // The part of w's run strictly below t survives this repaint.
            ch[w] = dep[t] + 1;
            t = k;
        }
        // Runs arrive bottom-up. A run above contributes pairs with every
        // already-seen (lower) node of strictly smaller value.
        lint ans = 0ll;
        for(const auto &a : v) {
            ans += lint(a.se) * gbit(a.fi - 1);
            updbit(a.fi, a.se);
        }
        for(const auto &a : v) updbit(a.fi, -a.se);  // reset the Fenwick tree
        // Writer i now paints the whole path root..B[i].
        ch[i] = 1;
        updseg(1, 1, N, dfsord[B[i]], i);
        printf("%lld\n", ans);
    }
    return 0;
}
Compilation message
construction.cpp: In function 'int main()':
construction.cpp:58:7: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
scanf("%d", &N);
~~~~~^~~~~~~~~~
construction.cpp:59:35: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
for(int i = 1; i <= N; i++) scanf("%d", C + i);
~~~~~^~~~~~~~~~~~~
construction.cpp:60:34: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
for(int i = 1; i < N; i++) scanf("%d%d", A + i, B + i);
~~~~~^~~~~~~~~~~~~~~~~~~~~~
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
4 ms |
2780 KB |
Output is correct |
2 |
Runtime error |
882 ms |
262144 KB |
Execution killed with signal 9 (could be triggered by violating memory limits) |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
4 ms |
2780 KB |
Output is correct |
2 |
Runtime error |
882 ms |
262144 KB |
Execution killed with signal 9 (could be triggered by violating memory limits) |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
4 ms |
2780 KB |
Output is correct |
2 |
Runtime error |
882 ms |
262144 KB |
Execution killed with signal 9 (could be triggered by violating memory limits) |
3 |
Halted |
0 ms |
0 KB |
- |