#include<bits/stdc++.h>
using namespace std;
// Capacities of the fixed-size global storage. Vertex ids are used 1-based,
// so arrays sized kMaxVertices support ids up to kMaxVertices - 1.
constexpr int kMaxVertices = 100001;
constexpr int kLiftLevels = 18;               // columns of the binary-lifting table
constexpr size_t kFenwickSize = 200001;       // slots of the Fenwick tree idk
constexpr size_t kSegmentTreeSize = 500001;   // nodes of the segment tree seg

vector<int> haha[kMaxVertices];               // haha[u]: adjacency list (main stores each edge as a -> b)
vector<int> st(kMaxVertices);                 // st[u]: subtree size, filled by dfs
vector<int> pos(kMaxVertices);                // pos[u]: 1-based pre-order index, filled by dfs
int bruh[kMaxVertices][kLiftLevels];          // bruh[u][i]: 2^i-th ancestor of u, or -1
vector<int> banana(kMaxVertices);             // banana[u]: compressed rank of u's input value
vector<int> dep(kMaxVertices);                // dep[u]: depth of u (stays 0 for the dfs root)
vector<int> idk(kFenwickSize);                // Fenwick tree indexed by compressed rank
vector<int> roll;                             // idk slots touched since the last reset
vector<pair<int,int>> seg(kSegmentTreeSize);  // max segment tree over pre-order positions
int z = 1;                                    // next pre-order index handed out by dfs
// Fenwick-tree point update: adds x to position a (1-based) of idk.
// Every slot written is appended to roll so that main can restore idk to
// all zeros after a query by touching only those slots.
// Positions a <= 0 are ignored: a == 0 would otherwise loop forever because
// 0 & -0 == 0, and negative positions have no slot.
void upd1(int a, int x) {
    if (a <= 0) {
        return;
    }
    const int limit = static_cast<int>(idk.size());
    for (; a < limit; a += a & (-a)) {
        idk[a] += x;
        roll.push_back(a);
    }
}
int calc1(int a) {
int c = 0,sb = 0;
for(int i = 18; i >= 0; i--) {
if(c+(1 << i) <= a) {
c+=(1 << i);
sb+=idk[c];
}
}
return sb;
}
// Pre-order traversal from vertex a, whose parent is t (-1 for the root).
// Neighbours are taken from haha in list order, and the parent is skipped.
// For every vertex u it reaches, it records:
//   pos[u]     - pre-order index taken from the global counter z (then z++),
//   st[u]      - subtree size, so u's subtree occupies the positions
//                [pos[u], pos[u] + st[u] - 1],
//   dep[u]     - dep[parent] + 1 (left untouched for a when t == -1),
//   bruh[u][0] - the parent of u (t for a itself).
// It uses an explicit stack instead of recursion because a path-shaped tree
// makes the recursion as deep as the tree (about 1e5 frames), which can
// overflow the call stack.
void dfs(int a, int t) {
    struct Frame {
        int node;
        int parent;
        size_t next;  // index into haha[node] of the next neighbour to try
    };
    vector<Frame> frames;

    // Work done when a vertex is first reached (the recursive prologue).
    auto enter = [&](int u, int p) {
        st[u] = 1;
        pos[u] = z;
        z++;
        if (p != -1) {
            dep[u] = dep[p] + 1;
        }
        bruh[u][0] = p;
        frames.push_back({u, p, 0});
    };

    enter(a, t);
    while (!frames.empty()) {
        // Copy the fields out: enter() may reallocate frames.
        const int u = frames.back().node;
        const int p = frames.back().parent;
        const size_t i = frames.back().next;
        if (i < haha[u].size()) {
            frames.back().next++;
            const int v = haha[u][i];
            if (v != p) {
                enter(v, u);
            }
        } else {
            // u is finished: fold its subtree size into its parent's frame.
            frames.pop_back();
            if (!frames.empty()) {
                st[frames.back().node] += st[u];
            }
        }
    }
}
// Point assignment in the max segment tree seg.
// Node x covers [l, r], and the children of node k are 2k and 2k+1.
// The leaf for position p receives c. Then every node on the path back up
// to x is recomputed as the max of its two children.
// Done iteratively: walk down to the leaf, then climb back by halving the
// node index until x is reached again.
void upd(int l, int r, int x, int p, pair<int,int> c) {
    const int top = x;
    int lo = l;
    int hi = r;
    int node = x;
    while (lo != hi) {
        const int mid = (lo + hi) / 2;
        if (p <= mid) {
            hi = mid;
            node = 2 * node;
        } else {
            lo = mid + 1;
            node = 2 * node + 1;
        }
    }
    seg[node] = c;
    while (node != top) {
        node /= 2;
        seg[node] = max(seg[2 * node], seg[2 * node + 1]);
    }
}
// Range-max query over the segment tree seg, where node x covers [l, r].
// Returns the largest pair over positions [ql, qr]. Pairs compare
// lexicographically: .first first, then .second.
// Assumes l <= ql <= qr <= r. When the query straddles the midpoint it is
// split into [ql, mid] and [mid+1, qr].
pair<int,int> calc(int l, int r, int x, int ql, int qr) {
    if (l == ql && r == qr) {
        return seg[x];
    }
    const int mid = (l + r) / 2;
    const int leftChild = 2 * x;
    const int rightChild = 2 * x + 1;
    const bool onlyLeft = qr <= mid;
    const bool onlyRight = ql > mid;
    if (onlyLeft) {
        return calc(l, mid, leftChild, ql, qr);
    }
    if (onlyRight) {
        return calc(mid + 1, r, rightChild, ql, qr);
    }
    const pair<int,int> leftBest = calc(l, mid, leftChild, ql, mid);
    const pair<int,int> rightBest = calc(mid + 1, r, rightChild, mid + 1, qr);
    return max(leftBest, rightBest);
}
// Returns the vertex br levels above a, using the binary-lifting table bruh
// (bruh[v][i] = 2^i-th ancestor of v, or -1). Only bits 0..17 of br are
// considered. The result is -1 once the jump passes above the root.
// The loop stops as soon as a becomes -1: continuing would index bruh[-1],
// an out-of-bounds read.
int lift(int a, int br) {
    for (int i = 17; i >= 0 && a != -1; i--) {
        if (br & (1 << i)) {
            a = bruh[a][i];
        }
    }
    return a;
}
// Reads the input, builds the tree and prints one answer per edge.
//
// Input, as read below: n, then n vertex values, then n-1 edges (a, b).
// Each edge is stored as a -> b, so the tree is rooted at vertex 1 with a as
// the parent of b.
// NOTE(review): the query loop assumes that when an edge is processed, a is
// already active (vertex 1, or the b of an earlier edge). Confirm against
// the problem statement.
//
// A vertex's current value is the compressed value of the most recently
// activated vertex in its subtree. seg holds (activation time, vertex) over
// pre-order positions.
// For each edge, the loop walks from a up to the root one run at a time; a
// run is the set of vertices whose value comes from the same activated
// vertex. It prints how many pairs (u above w on the root..a path) have
// value(u) > value(w), then activates b.
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int n = 0;
    // With n <= 0 the segment-tree calls below would recurse forever on the
    // empty range [1, 0], so stop on empty or invalid input.
    if (!(cin >> n) || n <= 0) {
        return 0;
    }

    // Compress the input values to ranks 1..y (equal values share a rank).
    // banana[v] is the rank of vertex v's value.
    vector<pair<int,int>> comp;
    comp.reserve(n);
    for (int i = 1; i <= n; i++) {
        int value;
        cin >> value;
        comp.push_back({value, i});
    }
    sort(comp.begin(), comp.end());
    int y = 1;
    for (size_t i = 0; i < comp.size(); i++) {
        if (i > 0 && comp[i].first != comp[i - 1].first) {
            y++;
        }
        banana[comp[i].second] = y;
    }

    // Read every edge up front: dfs needs the final tree to assign
    // pre-order positions before any query is answered.
    vector<pair<int,int>> edge;
    edge.reserve(n - 1);
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        cin >> a >> b;
        edge.push_back({a, b});
        haha[a].push_back(b);
    }
    dfs(1, -1);

    // Binary-lifting table: bruh[j][i] = 2^i-th ancestor of j, or -1.
    for (int i = 1; i < 18; i++) {
        for (int j = 1; j <= n; j++) {
            bruh[j][i] = (bruh[j][i - 1] == -1) ? -1 : bruh[bruh[j][i - 1]][i - 1];
        }
    }

    // lol[v]: depth of the topmost vertex whose value still comes from
    // activated vertex v (0 means v's run reaches the root).
    vector<int> lol(n + 1);
    // Vertex 1 is pre-order position 1 and is active from the start (time 1).
    upd(1, n, 1, 1, {1, 1});

    vector<pair<int,int>> wut;  // (rank, run length), deepest run first; reused per query
    for (int i = 0; i < n - 1; i++) {
        const int a = edge[i].first;
        const int b = edge[i].second;
        lol[b] = 0;  // once activated, b's run reaches the root
        wut.clear();
        int c = a;
        while (c != -1) {
            // The latest-activated vertex in c's subtree determines c's value.
            const int col = calc(1, n, 1, pos[c], pos[c] + st[c] - 1).second;
            const int e = dep[c] - lol[col] + 1;  // run length, counted from c upwards
            lol[col] = dep[c] + 1;               // c and everything above it is about to change
            wut.push_back({banana[col], e});
            c = lift(c, e);                      // first vertex above the run, or -1
        }
        upd(1, n, 1, pos[b], {i + 2, b});

        // Runs are processed deepest first, so the Fenwick tree holds the
        // ranks of vertices below the current run.
        long long ans = 0;
        for (const pair<int,int>& run : wut) {
            ans += static_cast<long long>(calc1(run.first - 1)) * run.second;
            upd1(run.first, run.second);
        }
        cout << ans << '\n';

        // Reset only the Fenwick slots this query touched.
        for (int v : roll) {
            idk[v] = 0;
        }
        roll.clear();
    }
    return 0;
}
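/*
 * Judge result table pasted from the submission page. It is not part of the
 * program and is kept inside a comment so the file compiles.
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */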