#include<bits/stdc++.h>
using namespace std;
// GCC-specific optimization pragma. NOTE(review): it appears after the #include
// above, so presumably it only affects code defined below it — confirm intent.
#pragma GCC optimize ("Ofast")
// Helper macros:
//   all(x)  -> "x.begin() , x.end()", to pass a whole container to an algorithm.
//              It is function-like, so a bare identifier `all` not followed by '('
//              (such as the segtree member below) is not expanded.
//   sze(x)  -> container size converted to ll.
//   mp(x,y) -> make_pair(x, y).
//   wall    -> prints a separator line to stderr (debug aid).
#define all(x) x.begin() , x.end()
#define sze(x) (ll)(x.size())
#define mp(x , y) make_pair(x , y)
#define wall cerr<<"--------------------------------------"<<endl
// Integer / pair aliases. In the code shown only ll and pll are used.
typedef long long int ll;
typedef pair<ll , ll> pll;
typedef pair<int , int> pii;
typedef long double db;
typedef pair<pll , ll> plll;
typedef pair<int , pii> piii;
typedef pair<pll , pll> pllll;
// maxn: capacity of the fixed-size per-vertex global arrays.
// md:   modulus used by tav().
// inf:  large sentinel (not referenced in the code shown).
const ll maxn = 5e5 + 16 , md = 1e9 + 7 , inf = 2e16;
// Greatest common divisor by the Euclidean algorithm, written as a loop.
// Each round orders the pair so that a >= b, stops once b is 0 (answer a),
// and otherwise continues with (b, a % b).
ll gcd(ll a , ll b){
    for(;;){
        if(a < b) swap(a , b);
        if(!b) return a;
        const ll rem = a % b;
        a = b;
        b = rem;
    }
}
// Modular exponentiation: returns n^k mod md using binary exponentiation.
// Returns 1 when k <= 0.
// The base is first reduced into [0, md). Every product below then stays
// under md^2 (about 1e18) and cannot overflow ll. Without the reduction,
// an input base >= md made n * n overflow.
ll tav(ll n , ll k){
    n %= md;
    if(n < 0) n += md;
    ll res = 1;
    while(k > 0){
        if(k & 1) res = res * n % md;
        n = n * n % md;
        k >>= 1;
    }
    return res;
}
// Fenwick (binary indexed) tree of ll sums over indices [0, sz). It uses the
// 0-based scheme: step up with h |= h + 1, step down with h = (h & (h + 1)) - 1.
struct fentree {
    ll sz = 0;
    vector<ll> val;
    // Resets the tree to n zero-valued slots.
    void init(ll n){
        sz = n;
        val.assign(sz , 0);
    }
    // Adds k at index i.
    // A negative i is ignored: for h = -1 the update step "h |= h + 1" never
    // leaves -1, so it would loop forever writing val[-1].
    // An i >= sz is already a no-op.
    void add(ll i , ll k){
        if(i < 0) return;
        for(ll h = i ; h < sz ; h |= (h + 1)) val[h] += k;
    }
    // Returns the sum over indices [0, i]. Any i < 0 yields 0, and an i past
    // the end is clamped to sz - 1 instead of reading outside the buffer.
    ll cal(ll i){
        ll sum = 0;
        for(ll h = min(i , sz - 1) ; h > -1 ; h = (h & (h + 1)) - 1) sum += val[h];
        return sum;
    }
    // Frees the storage and resets the size, so add/cal on a cleared tree
    // cannot index the released buffer.
    void clear(){
        val.clear();
        sz = 0;
    }
};
// Scratch buffers used by upd():
//   res - segtree::set appends (old value, run length) pairs here.
//   q   - upd() copies those runs here, merging adjacent runs of equal value.
vector<pll> res , q;
struct segtree {
ll sz = 1;
vector<ll> val;
vector<bool> all;
void init(ll n){
while(sz < n) sz <<= 1;
val.assign(sz << 1 , -1);
all.assign(sz << 1 , true);
return;
}
void set(ll l , ll r , ll k , ll x , ll lx , ll rx){
if(rx <= l || lx >= r) return;
if(rx <= r && lx >= l && all[x]){
res.push_back({val[x] , rx - lx});
val[x] = k;
return;
}
ll m = (rx + lx) >> 1 , ln = (x << 1) + 1 , rn = ln + 1;
set(l , r , k , rn , m , rx); set(l , r , k , ln , lx , m);
if(val[ln] == val[rn] && all[ln] && all[rn]){
val[x] = val[ln]; all[x] = true;
} else {
all[x] = false;
}
return;
}
void set(ll l , ll r , ll k){
set(l , r , k , 0 , 0 , sz);
return;
}
};
// Fenwick tree indexed by compressed value; upd() uses it to count pairs along a path.
fentree ft;
// Segment tree indexed by HLD position (lb[]) holding each vertex's current compressed value.
segtree st;
// adj: adjacency lists of the input tree.
// v:   the sorted distinct input values, used for coordinate compression in main.
vector<ll> adj[maxn] , v;
// a:   input value per vertex, replaced in main by its rank in v.
// pr:  parent set by aDFS (-1 for the root).
// z:   per-vertex value that aDFS compares across children to choose hc.
// hc:  heavy child chosen by aDFS (-1 if the vertex has no children).
// hp:  head (topmost vertex) of the heavy chain containing the vertex.
// dis: depth from the root, set by aDFS.
// dep: depth counter maintained during aDFS.
// lb:  position assigned by lDFS.
// cur: next free position for lDFS.
ll a[maxn] , pr[maxn] , z[maxn] , hc[maxn] , hp[maxn] , dis[maxn] , dep = 0 , lb[maxn] , cur = 0;
// First heavy-light pass, DFS from r whose parent is par (-1 for the root).
// For every vertex in r's subtree it records:
//   dis - depth
//   pr  - parent
//   z   - subtree size
//   hc  - heavy child: the child with the largest subtree (first one on ties),
//         or -1 for a leaf.
// Recursive, so stack depth equals the tree height.
void aDFS(ll r , ll par){
    dis[r] = dep++;
    pr[r] = par;
    z[r] = 1;
    ll m = -1 , ind = -1;
    for(auto i : adj[r]){
        if(i == par) continue;
        aDFS(i , r);
        // Accumulate the subtree size. Without this every z[] stays 1, the
        // "heavy" child is just the first child, and a root path can cross
        // O(n) chains instead of O(log n).
        z[r] += z[i];
        if(z[i] > m){
            m = z[i];
            ind = i;
        }
    }
    hc[r] = ind;
    dep--;
}
// Second heavy-light pass: numbers vertices in DFS order (lb[]) and records
// chain heads (hp[]).
// The heavy child is numbered immediately after r and inherits r's chain
// head, so every heavy chain occupies consecutive positions. Every other
// child starts a new chain headed by itself.
// hp[r] must already be set by the caller.
void lDFS(ll r , ll par){
    lb[r] = cur++;
    const ll heavy = hc[r];
    if(heavy != -1){
        hp[heavy] = hp[r];
        lDFS(heavy , r);
        for(ll child : adj[r]){
            if(child != par && child != heavy){
                hp[child] = child;
                lDFS(child , r);
            }
        }
    }
}
// Result of the most recent upd() call; main prints it after each edge.
ll ans;
// Handles the vertex v that has just been attached to the tree.
// 1. Paints every vertex on the root->v path with a[v], one heavy chain at a
//    time, collecting the values it overwrites as runs in `res`.
// 2. Stores in `ans` the number of pairs (x, y) on that path where x is closer
//    to the root and the old value of x is strictly greater than the old
//    value of y.
// Assumes every proper ancestor of v already holds a value and v itself does
// not — presumably guaranteed by the input's parent-first edge order; confirm.
void upd(ll v){
    res.clear(); q.clear(); ans = 0;
    for(ll h = v ; h > -1 ; h = pr[hp[h]]){
        st.set(lb[hp[h]] , lb[h] + 1 , a[v]);
    }
    // `res` now lists runs from v up to the root (chains are visited bottom-up,
    // and each set() reports high positions first).
    // Runs that were never assigned (value -1: v itself) are dropped. This also
    // keeps -1 out of the Fenwick tree.
    // Adjacent runs with equal values are merged.
    for(const auto &p : res){
        if(p.first < 0) continue;
        if(!q.empty() && q.back().first == p.first){
            q.back().second += p.second;
        } else {
            q.push_back(p);
        }
    }
    // Everything already in the Fenwick tree lies deeper on the path than the
    // current run. Count only strictly smaller values: equal values are not
    // inversions. This is consistent with pairs inside a merged run never
    // being counted; using <= made the answer depend on how runs were split.
    for(const auto &p : q){
        ans += p.second * ft.cal(p.first - 1);
        ft.add(p.first , p.second);
    }
    // Undo the insertions so the Fenwick tree is all zeros again for the next query.
    for(const auto &p : q){
        ft.add(p.first , -p.second);
    }
}
// Input edges with 0-based endpoints, in reading order; main answers them in this order.
vector<pll> ed;
int main(){
ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
ll n;
cin>>n;
for(ll i = 0 ; i < n ; i++){
cin>>a[i];
v.push_back(a[i]);
}
sort(all(v));
v.resize(distance(v.begin() , unique(all(v))));
ft.init(sze(v));
for(ll i = 0 ; i < n ; i++){
a[i] = lower_bound(all(v) , a[i]) - v.begin();
}
st.init(n);
st.set(0 , 1 , a[0]);
for(ll i = 1 ; i < n ; i++){
ll v , u;
cin>>v>>u; v--; u--;
adj[v].push_back(u); adj[u].push_back(v);
ed.push_back({v , u});
}
aDFS(0 , -1);
lDFS(0 , -1);
hp[0] = 0;
for(auto p : ed){
ll v = p.first , u = p.second , i;
if(pr[v] == u){
i = v;
} else {
i = u;
}
upd(i);
cout<<ans<<'\n';
}
return 0;
}
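/*
 * Judge results for this submission (table reflowed; Korean headers translated):
 * | # | Result    | Execution time | Memory   | Grader output        |
 * |---|-----------|----------------|----------|----------------------|
 * | 1 | Correct   | 8 ms           | 12108 KB | Output is correct    |
 * | 2 | Incorrect | 8 ms           | 12060 KB | Output isn't correct |
 * | 3 | Halted    | 0 ms           | 0 KB     | -                    |
 */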
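/*
 * Judge results for this submission (table reflowed; Korean headers translated):
 * | # | Result    | Execution time | Memory   | Grader output        |
 * |---|-----------|----------------|----------|----------------------|
 * | 1 | Correct   | 8 ms           | 12108 KB | Output is correct    |
 * | 2 | Incorrect | 8 ms           | 12060 KB | Output isn't correct |
 * | 3 | Halted    | 0 ms           | 0 KB     | -                    |
 */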
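/*
 * Judge results for this submission (table reflowed; Korean headers translated):
 * | # | Result    | Execution time | Memory   | Grader output        |
 * |---|-----------|----------------|----------|----------------------|
 * | 1 | Correct   | 8 ms           | 12108 KB | Output is correct    |
 * | 2 | Incorrect | 8 ms           | 12060 KB | Output isn't correct |
 * | 3 | Halted    | 0 ms           | 0 KB     | -                    |
 */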