#include<bits/stdc++.h>
using namespace std;
// Large sentinel values; neither is referenced elsewhere in this file.
const long long inf = (long long) 1e18 + 10;
// Declared before the `int` macro below, so this is still a plain int.
const int inf1 = (int) 1e9 + 10;
// From this line on every `int` -- including inside templates such as
// pair<int,int> -- is long long (presumably so the inversion totals printed
// by solve() cannot overflow). main() is therefore declared as int32_t.
#define int long long
// Short-hand macros used throughout the rest of the file.
#define dbl long double
#define endl '\n'
#define sc second
#define fr first
#define mp make_pair
#define pb push_back
#define all(x) x.begin(), x.end()
// Capacity of every per-node / per-road array below. Node ids and n are
// assumed to stay below this value -- NOTE(review): confirm against the
// problem constraints.
#define maxn 100100
// Global state, indexed from 1 (0 is used as the "no node" value in ant[]):
//   n          number of nodes
//   a[i], b[i] endpoints of the i-th road, stored as the edge a[i] -> b[i]
//   c[v]       colour of v after negation and compression to 1..k
//   sz[v]      subtree size of v (filled by dfssz)
//   pai[v]     head of the heavy chain containing v (filled by dfshld)
//   ant[v]     parent of v; 0 for the root
//   h[v]       depth of v, with the root at depth 0
//   bit[]      Fenwick tree over compressed colours (see att / qrr)
int n, a[maxn], b[maxn], c[maxn], sz[maxn], pai[maxn], ant[maxn], h[maxn], bit[maxn];
// g[v]: children of v (edges are only stored in the a[i] -> b[i] direction).
vector<int> g[maxn];
// bl[head]: colour runs of the heavy chain whose head is `head`, each stored
// as ((min depth, max depth), colour). Runs are kept deepest-first, so back()
// is the run that contains the chain head itself.
vector<pair<pair<int,int>,int>> bl[maxn];
// Fenwick-tree point update: adds `val` at 1-based index `pos`.
// Walks upward through every node covering `pos`, stopping past index n.
void att(int pos, int val) {
    while (pos <= n) {
        bit[pos] += val;
        pos += pos & -pos;  // jump to the next node whose range covers pos
    }
}
// Fenwick-tree prefix query: returns the sum of the values stored at
// indices 1..pos (0 when pos <= 0).
int qrr(int pos) {
    int sum = 0;
    // Clearing the lowest set bit moves to the node covering the preceding range.
    for (; pos > 0; pos &= pos - 1) {
        sum += bit[pos];
    }
    return sum;
}
// Heavy-light decomposition of u's subtree (requires sz[] from dfssz).
// The child with the largest subtree -- ties broken toward the larger id --
// continues u's chain (inherits pai[u]); every other child becomes the head
// of a new chain. Also records u as the parent (ant[]) of each child.
// The heavy child is visited before the light ones.
void dfshld(int u) {
    int heavy = 0;
    int heavySize = 0;
    for (int child : g[u]) {
        if (sz[child] > heavySize || (sz[child] == heavySize && child > heavy)) {
            heavySize = sz[child];
            heavy = child;
        }
    }
    if (heavy == 0) {
        return;  // no children
    }

    ant[heavy] = u;
    pai[heavy] = pai[u];
    dfshld(heavy);

    for (int child : g[u]) {
        if (child != heavy) {
            ant[child] = u;
            pai[child] = child;
            dfshld(child);
        }
    }
}
// Fills sz[u] with the size of u's subtree and assigns depths h[] to all
// of u's descendants (h[u] itself must already be set). `parent` is skipped
// when iterating u's adjacency list so the walk never goes back up.
void dfssz(int u, int parent) {
    sz[u] = 1;
    for (int child : g[u]) {
        if (child != parent) {
            h[child] = h[u] + 1;
            dfssz(child, u);
            sz[u] += sz[child];
        }
    }
}
void solve() {
cin >> n;
vector<int> cc;
for(int i = 1; i <= n; i++) {
cin >> c[i];
c[i] = -c[i];
cc.pb(c[i]);
}
sort(all(cc));
cc.erase(unique(all(cc)),cc.end());
for(int i = 1; i <= n; i++) {
c[i] = upper_bound(all(cc),c[i]) - cc.begin();
}
for(int i = 1; i <= n-1; i++) {
cin >> a[i] >> b[i];
g[a[i]].pb(b[i]);
}
dfssz(1,1);
ant[1] = 0;
pai[1] = 1;
dfshld(1);
// for(int i = 1; i <= n; i++) {
// cout << i << " " << ant[i] << " " << pai[i] << endl;
// }
bl[1].pb(mp(mp(0,0),c[1]));
//mp(mp(altura minima,altura maxima),cor)
int cnt = 0;
for(int i = 1; i <= n-1; i++) {
int u = a[i];
vector<pair<pair<int,int>,int>> cols;
int v = u;
while(v != 0) {
// cout << "- " << v << endl;
if(pai[v] != v) {
for(auto x : bl[pai[v]]) {
if(x.fr.fr <= h[v]) cols.pb(mp(mp(x.fr.fr,min(h[v],x.fr.sc)),x.sc));
}
}
else {
cols.pb(mp(mp(h[v],h[v]),bl[v].back().sc));
}
v = ant[pai[v]];
}
// cout << b[i] << " " << c[b[i]] << endl;
sort(all(cols));
int ans = 0;
for(auto X : cols) {
int qtd = X.fr.sc-X.fr.fr+1;
int col = X.sc;
// cout << " " << X.fr.fr << " " << X.fr.sc << " " << col << endl;
ans+= qtd*qrr(col-1);
att(col,qtd);
cnt++;
}
cout << ans << endl;
int ant1 = 0;
for(auto X : cols) {
int qtd = X.fr.sc-X.fr.fr+1;
int col = X.sc;
if(ant1 != col) cnt++;
ant1 = col;
att(col,-qtd);
}
//atualizacao
//bl[i] tem que estar em ordem decrescente, com o bloco do pai sendo o ultimo
v = b[i];
int newc = c[b[i]];
while(v != 0) {
// cout << " " << v << endl;
while(bl[pai[v]].size() && bl[pai[v]].back().fr.sc <= h[v]) bl[pai[v]].pop_back();
if(bl[pai[v]].size() != 0) {
auto x = bl[pai[v]].back();
// cout << " " << x.fr.fr << " " << x.fr.sc << " " << x.sc << " " << h[v] << endl;
bl[pai[v]].pop_back();
bl[pai[v]].pb(mp(mp(h[v]+1,x.fr.sc),x.sc));
}
bl[pai[v]].pb(mp(mp(h[pai[v]],h[v]),newc));
v = ant[pai[v]];
}
}
assert(cnt <= n*5);
}
// Entry point: enables fast unsynchronised iostreams and runs exactly one
// test case. Declared as int32_t because `int` is macro-expanded to long long.
int32_t main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int testCount = 1;  // multi-test input is not read; always one case
    while (testCount-- > 0) {
        solve();
    }
    return 0;
}
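/*
 * Judge results pasted with this submission. Signal 6 is SIGABRT on Linux,
 * as raised, for example, by a failing assert().
 *
 *   #  | Verdict       | Execution time | Memory   | Grader output
 *  ----+---------------+----------------+----------+--------------------------------
 *   1  | Correct       | 2 ms           | 4940 KB  | Output is correct
 *   2  | Correct       | 3 ms           | 4940 KB  | Output is correct
 *   3  | Correct       | 3 ms           | 5068 KB  | Output is correct
 *   4  | Runtime error | 7 ms           | 10188 KB | Execution killed with signal 6
 *   5  | Halted        | 0 ms           | 0 KB     | -
 */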
/*
 * Judge results pasted with this submission (second copy):
 *
 *   #  | Verdict       | Execution time | Memory   | Grader output
 *  ----+---------------+----------------+----------+--------------------------------
 *   1  | Correct       | 2 ms           | 4940 KB  | Output is correct
 *   2  | Correct       | 3 ms           | 4940 KB  | Output is correct
 *   3  | Correct       | 3 ms           | 5068 KB  | Output is correct
 *   4  | Runtime error | 7 ms           | 10188 KB | Execution killed with signal 6
 *   5  | Halted        | 0 ms           | 0 KB     | -
 */
/*
 * Judge results pasted with this submission (third copy):
 *
 *   #  | Verdict       | Execution time | Memory   | Grader output
 *  ----+---------------+----------------+----------+--------------------------------
 *   1  | Correct       | 2 ms           | 4940 KB  | Output is correct
 *   2  | Correct       | 3 ms           | 4940 KB  | Output is correct
 *   3  | Correct       | 3 ms           | 5068 KB  | Output is correct
 *   4  | Runtime error | 7 ms           | 10188 KB | Execution killed with signal 6
 *   5  | Halted        | 0 ms           | 0 KB     | -
 */