#include <bits/stdc++.h>
using namespace std;
// Short aliases and macros used throughout this solution.
using ll = long long;          // signed integer type of at least 64 bits
using pii = pair<int, int>;    // pair of ints
#define pb push_back           // shorthand for push_back
#define ff first               // shorthand for pair::first
#define ss second              // shorthand for pair::second
#define ins insert             // shorthand for insert
#define arr3 array<int, 3>     // int triple; main stores {first pos, last pos, value} runs in it
// Fenwick (binary indexed) tree over indices 0..n using the 0-based
// "i & (i + 1)" layout: bit[i] stores the sum of the index range
// [i & (i + 1), i].
struct FT{
    std::vector<int> bit;
    int n;
    // Creates a tree with n + 1 zero-initialised slots (indices 0..n).
    FT(int ns){
        n = ns;
        bit.resize(n + 1);
    }
    // Adds k at index v. Indices outside [0, n] are ignored; a negative v
    // would otherwise read bit[-1] and loop forever, since -1 | 0 == -1.
    void upd(int v, int k){
        if (v < 0) return;
        for (; v <= n; v |= (v + 1)) bit[v] += k;
    }
    // Returns the prefix sum over [0, v] (0 for v < 0; v > n is clamped to n).
    int get(int v){
        v = std::min(v, n);
        int out = 0;
        // Loop while v >= 0, not v > 0: in this layout bit[0] covers index 0
        // on its own, so stopping at 0 would drop that index's contribution.
        for (; v >= 0; v = (v & (v + 1)) - 1) out += bit[v];
        return out;
    }
    // Returns the sum over the inclusive index range [l, r].
    int get(int l, int r){
        return get(r) - get(l - 1);
    }
};
int main(){
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int n; cin>>n;
vector<int> a(n + 1), all;
for (int i = 1; i <= n; i++){
cin>>a[i];
all.pb(a[i]);
}
sort(all.begin(), all.end());
map<int, int> mp;
int i = 0, cnt = 0;
while (i < n){
int j = i + 1;
while (j < n && all[i] == all[j]){
j++;
}
mp[all[i]] = ++cnt;
i = j;
}
for (int i = 1; i <= n; i++) a[i] = mp[a[i]];
vector<int> x(n), y(n), g[n + 1];
for (int i = 1; i < n; i++){
cin>>x[i]>>y[i];
g[x[i]].pb(y[i]);
g[y[i]].pb(x[i]);
}
vector<int> sz(n + 1), d(n + 1), h(n + 1), p(n + 1);
function<void(int, int)> fill = [&](int v, int pr){
p[v] = pr;
d[v] = d[pr] + 1;
sz[v] = 1;
for (int i: g[v]){
if (i == pr) continue;
fill(i, v);
if (sz[i] > sz[h[v]]){
h[v] = i;
}
sz[v] += sz[i];
}
};
fill(1, 0);
vector<int> head(n + 1), pos(n + 1);
int timer = 0;
function<void(int, int)> fill_hld = [&](int v, int k){
head[v] = k;
pos[v] = ++timer;
if (!h[v]) return;
fill_hld(h[v], k);
for (int i: g[v]){
if (pos[i]) continue;
fill_hld(i, i);
}
};
fill_hld(1, 1);
set<arr3> st;
for (int i = 1; i <= n; i++) st.ins({pos[i], pos[i], a[i]});
vector<pii> ch;
vector<arr3> er, sg;
FT T(n);
for (int i = 1; i < n; i++){
int v = x[i];
while (v > 0){
ch.pb({pos[head[v]], pos[v]});
v = p[head[v]];
}
reverse(ch.begin(), ch.end());
for (auto [l, r]: ch){
auto it = st.lower_bound({l, 0});
while (it != st.end() && (*it)[1] <= r){
er.pb(*it);
it++;
}
for (auto f: er) st.erase(f);
it = st.lower_bound({l + 1, 0});
if (it != st.begin()){
it--;
auto [l1, r1, k] = *it;
if (l1 <= l && l <= r1){
st.erase(it);
if (l1 < l){
st.ins({l1, l - 1, k});
}
if (r < r1){
st.ins({r + 1, r1, k});
}
sg.pb({max(l1, l), min(r1, r), k});
}
}
for (auto f: er) sg.pb(f);
er.clear();
it = st.lower_bound({r + 1, 0});
if (it != st.begin()){
it--;
auto [l1, r1, k] = (*it);
if (l1 <= r && r <= r1){
st.erase(it);
assert(l1 >= l);
if (r < r1){
st.ins({r + 1, r1, k});
}
sg.pb({max(l1, l), min(r1, r), k});
}
}
}
ll out = 0;
for (auto [l, r, k]: sg){
out += T.get(k + 1, n);
T.upd(k, (r - l + 1));
}
for (auto [l, r, k]: sg) T.upd(k, -(r - l + 1));
cout<<out<<"\n";
for (auto [l, r]: ch) st.ins({l, r, a[y[i]]});
ch.clear(); sg.clear();
}
}
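/*
 * Judge results pasted from the submission page (column headers translated
 * from Korean).
 *
 * #  | Result    | Execution time | Memory | Grader output
 * 1  | Correct   | 1 ms           | 348 KB | Output is correct
 * 2  | Correct   | 0 ms           | 348 KB | Output is correct
 * 3  | Incorrect | 1 ms           | 348 KB | Output isn't correct
 * 4  | Halted    | 0 ms           | 0 KB   | -
 */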
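/*
 * Judge results pasted from the submission page (column headers translated
 * from Korean).
 *
 * #  | Result    | Execution time | Memory | Grader output
 * 1  | Correct   | 1 ms           | 348 KB | Output is correct
 * 2  | Correct   | 0 ms           | 348 KB | Output is correct
 * 3  | Incorrect | 1 ms           | 348 KB | Output isn't correct
 * 4  | Halted    | 0 ms           | 0 KB   | -
 */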
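/*
 * Judge results pasted from the submission page (column headers translated
 * from Korean).
 *
 * #  | Result    | Execution time | Memory | Grader output
 * 1  | Correct   | 1 ms           | 348 KB | Output is correct
 * 2  | Correct   | 0 ms           | 348 KB | Output is correct
 * 3  | Incorrect | 1 ms           | 348 KB | Output isn't correct
 * 4  | Halted    | 0 ms           | 0 KB   | -
 */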