#include <bits/stdc++.h>
using namespace std;
// Short type aliases used throughout the solution.
using ii = pair<int, int>;
using ll = long long;
using vi = vector<int>;
// Token-level shorthands: pair member access, a container's full iterator
// range (for algorithm calls), and a container's size as a signed int.
#define fi first
#define se second
#define all(v) begin(v), end(v)
#define sz(v) (int)(v.size())
// A run of consecutive HLD positions [l, r] (both inclusive) that all hold value x.
// Runs are ordered by left endpoint only; the solution keeps them pairwise disjoint
// inside a std::set, where that ordering is enough to find the run covering a position.
struct interval {
  int l;  // first covered position
  int r;  // last covered position
  int x;  // value shared by every position in [l, r]
  bool operator < (const interval& rhs) const { return rhs.l > l; }
};
int main(int argc, char const *argv[])
{
#ifdef LOCAL
freopen("in2", "r", stdin);
#endif
int n; scanf("%d", &n);
vi C(n);
for(int i = 0; i < n; ++i) {
scanf("%d", &C[i]);
}
vi CC = C; sort(all(CC)); CC.erase(unique(all(CC)), CC.end());
for(int i = 0; i < n; ++i) {
C[i] = int(lower_bound(all(CC), C[i]) - CC.begin());
}
vi A(n-1), B(n-1);
vector<vi> g(n);
for(int i = 0; i < n-1; ++i) {
scanf("%d %d", &A[i], &B[i]);
A[i]--, B[i]--;
g[A[i]].emplace_back(B[i]);
g[B[i]].emplace_back(A[i]);
}
vi in(n), sub(n), who(n), par(n);
int tym = 0;
function<void(int,int)> dfs0 = [&](int u, int dad) {
par[u] = dad;
sub[u] = 1;
for(int v : g[u]) if(v != dad) {
dfs0(v, u);
sub[u] += sub[v];
}
};
dfs0(0, -1);
vi head(n);
function<void(int,int,int)> hld_build = [&](int u, int dad, int Head) {
who[in[u] = tym++] = u;
head[u] = Head;
ii big(0, -1);
for(int v : g[u]) if(v != dad) {
big = max(big, ii(sub[v], v));
}
if(big.fi) hld_build(big.se, u, Head);
for(int v : g[u]) if(v != dad and v != big.se) {
hld_build(v, u, v);
}
};
hld_build(0, -1, 0);
set<interval> st;
auto set_range = [&](int L, int R, int x) {
auto lit = --st.upper_bound({L,0,0});
auto rit = st.upper_bound({R,0,0});
int lb = lit -> l, rb = prev(rit) -> r;
int lx = lit -> x, rx = prev(rit) -> x;
vector<interval> ret;
for(auto it = lit; it != rit; ++it) {
ret.push_back({max(L, it -> l), min(R, it -> r), it -> x});
}
st.erase(lit, rit);
if(lb < L) st.insert(rit, {lb, L-1, lx});
st.insert(rit, {L, R, x});
if(rb > R) st.insert(rit, {R+1, rb, rx});
return ret;
};
for(int i = 0; i < n; ++i) {
st.insert({in[i], in[i], C[i]});
}
auto rev_append = [](vector<interval>& X, vector<interval> Y) {
reverse(all(Y));
X.insert(X.end(), all(Y));
};
vi cnt(n+5);
for(int j = 0; j < n-1; ++j) {
int u = A[j];
vector<interval> ls;
do {
rev_append(ls, set_range(in[head[u]], in[u], C[B[j]]));
u = par[head[u]];
} while(~u);
reverse(all(ls));
vi c(sz(ls)), v(sz(ls));
for(int i = 0; i < sz(ls); ++i) {
c[i] = ls[i].r - ls[i].l + 1;
v[i] = ls[i].x;
}
ll ans = 0;
for(int it = 0; it < 17; ++it) {
for(int i = 0; i < sz(v); ++i) {
if(~v[i] & 1) {
ans += cnt[v[i] | 1];
}
cnt[v[i]] += c[i];
}
for(int i = 0; i < sz(v); ++i) {
cnt[v[i]] -= c[i];
v[i] >>= 1;
}
}
printf("%lld\n", ans);
}
return 0;
}
/*
Compilation message (the line numbers appear to refer to the judged file,
construction.cpp, not to this copy):
construction.cpp: In function 'int main(int, const char**)':
construction.cpp:23:15: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
   23 | int n; scanf("%d", &n);
      |        ~~~~~^~~~~~~~~~
construction.cpp:26:10: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
   26 | scanf("%d", &C[i]);
      | ~~~~~^~~~~~~~~~~~~
construction.cpp:35:10: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
   35 | scanf("%d %d", &A[i], &B[i]);
      | ~~~~~^~~~~~~~~~~~~~~~~~~~~~~
*/
/*
Grader results:
  # | Result    | Time | Memory | Grader output
  1 | Correct   | 1 ms | 204 KB | Output is correct
  2 | Correct   | 1 ms | 204 KB | Output is correct
  3 | Incorrect | 1 ms | 204 KB | Output isn't correct
  4 | Halted    | 0 ms | 0 KB   | -
*/
/*
Grader results:
  # | Result    | Time | Memory | Grader output
  1 | Correct   | 1 ms | 204 KB | Output is correct
  2 | Correct   | 1 ms | 204 KB | Output is correct
  3 | Incorrect | 1 ms | 204 KB | Output isn't correct
  4 | Halted    | 0 ms | 0 KB   | -
*/
/*
Grader results:
  # | Result    | Time | Memory | Grader output
  1 | Correct   | 1 ms | 204 KB | Output is correct
  2 | Correct   | 1 ms | 204 KB | Output is correct
  3 | Incorrect | 1 ms | 204 KB | Output isn't correct
  4 | Halted    | 0 ms | 0 KB   | -
*/