#include<bits/stdc++.h>
// Inclusive counting loop: a runs from b up to and including c.
#define rep(a,b,c) for(ll a=b;a<=c;++a)
// Shorthand for the 64-bit integer type used throughout the file.
#define ll long long 
// Shorthands for pair members and pair construction.
#define ff first 
#define ss second 
#define mp make_pair 
using namespace std;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
// N: capacity of every per-vertex array (1e5 plus slack). inf is not used in the visible code.
const ll N=1e5+5,inf=1e18;
// n      - number of vertices (read in main); vertices and Fenwick indices are 1-based.
// C      - per-vertex value; DIS_BS replaces it with its 1-based rank among distinct values.
// A, B   - endpoints of the i-th input edge, also the arguments of the i-th solve() call.
// sz     - subtree size, par - parent, D - depth (the DFS root keeps the zero-initialised depth 0).
// hson   - child with the largest subtree, or -1 for a leaf.
// tp     - top vertex of the heavy chain containing the vertex.
// pos    - visit index handed out by Link via timer.
// num    - index of the chain, stored at the chain's top vertex; cnt counts chains.
// BIT    - Fenwick tree over compressed values, indices 1..n.
ll n,C[N],A[N],B[N],sz[N],par[N],hson[N],D[N],tp[N],pos[N],timer,cnt,num[N],BIT[N];
// Adjacency lists of the tree.
vector<ll> adj[N];
// RLE[id]: run-length encoded (value,count) list for chain id; the back entry covers the chain's top.
vector<pll> RLE[N];
// tmp: scratch for runs popped from one chain; res: runs collected along the whole path being processed.
vector<pll> tmp,res;
// Returns the value of the lowest set bit of x (0 when x is 0).
ll lowbit(ll x){
    const ll negated=-x;
    return negated&x;
}
// Fenwick point update: adds v at 1-based index pos by touching every
// tree node whose range covers pos, up to the global size n.
void modify(ll pos,ll v){
    ll node=pos;
    while(node<=n){
        BIT[node]+=v;
        node+=lowbit(node);
    }
}
// Fenwick range sum over the inclusive 1-based index range [l, r],
// computed as prefix(r) - prefix(l-1).
ll query(ll l,ll r){
    auto prefix=[](ll idx){
        ll total=0;
        while(idx>=1){
            total+=BIT[idx];
            idx-=lowbit(idx);
        }
        return total;
    };
    const ll upToR=prefix(r);
    const ll beforeL=prefix(l-1);
    return upToR-beforeL;
}
void DIS_BS(){
    vector<ll> T;
    rep(i,1,n) T.push_back(C[i]);
    sort(T.begin(),T.end());
    T.resize(unique(T.begin(),T.end())-T.begin());
    rep(i,1,n) C[i]=lower_bound(T.begin(),T.end(),C[i])-T.begin()+1;
}
// First HLD pass: fills sz (subtree size), par, D (depth) and hson
// (child with the largest subtree, -1 if none) for the subtree of u.
// lst is the vertex we arrived from and is skipped in the adjacency list.
void dfs_sz(ll u,ll lst){
    hson[u]=-1;
    sz[u]=1;
    for(ll child:adj[u]){
        if(child!=lst){
            par[child]=u;
            D[child]=D[u]+1;
            dfs_sz(child,u);
            sz[u]+=sz[child];
            // The -1 check must come first so sz[] is never indexed with -1.
            const bool noHeavyYet=(hson[u]==-1);
            if(noHeavyYet||sz[hson[u]]<sz[child]){
                hson[u]=child;
            }
        }
    }
}
// Second HLD pass: assigns u to the chain headed by top, numbers the chain
// on first contact, then descends into the heavy child (same chain) before
// the light children (each starting a new chain).
// The run for u is appended only after its heavy descendants, so the chain's
// top vertex ends up at the back of RLE[chain].
void Link(ll u,ll top){
    tp[u]=top;
    pos[u]=++timer;
    if(num[top]==0){
        ++cnt;
        num[top]=cnt;
    }
    const ll heavy=hson[u];
    if(heavy!=-1){
        Link(heavy,top);
    }
    for(ll v:adj[u]){
        const bool isLightChild=(v!=par[u])&&(v!=heavy);
        if(isLightChild){
            Link(v,v);
        }
    }
    RLE[num[top]].push_back(pll(C[u],1));
}
void HLD(){
    dfs_sz(1,-1);
    Link(1,1);
}
// Processes the part of one heavy chain between a and b (both passed by
// reference; they are swapped so that a is the deeper one).
//
// The topmost D[a]-D[b]+1 vertices of a's chain are removed from that chain's
// run-length list, appended to the global `res` deepest-run-first, and then
// put back as a single run carrying the value C[v] (merged with any directly
// following runs that already hold C[v]).
//
// Chain lists keep the chain's top vertex at the back, so "topmost" means
// popping from the back. The global scratch vector `tmp` is expected to be
// empty on entry and is left empty on exit.
void left(ll &a,ll &b,ll v){
    if(D[a]<D[b]) swap(a,b);
    vector<pll> &runs=RLE[num[tp[a]]];
    ll step=D[a]-D[b]+1;
    ll cur=step;
    // Take whole runs while they fit. The emptiness check matters: when the
    // requested prefix is the entire chain the list becomes empty here, and
    // calling back() on an empty vector is undefined behaviour.
    while(!runs.empty()&&runs.back().ss<=cur){
        tmp.push_back(runs.back());
        cur-=runs.back().ss;
        runs.pop_back();
    }
    // Split the next run if only part of it belongs to the prefix.
    if(cur>0&&!runs.empty()){
        tmp.push_back(mp(runs.back().ff,cur));
        runs.back().ss-=cur;
    }
    // tmp holds the runs top-first; res wants them deepest-first.
    res.insert(res.end(),tmp.rbegin(),tmp.rend());
    tmp.clear();
    // Absorb following runs that already have the new value, so that equal
    // neighbours stay merged.
    while(!runs.empty()&&runs.back().ff==C[v]){
        step+=runs.back().ss;
        runs.pop_back();
    }
    runs.push_back(mp(C[v],step));
}
void solve(ll u,ll v){
    ll a=u,b=1;
    while(tp[a]!=tp[b]){
        if(D[tp[a]]<D[tp[b]]) swap(a,b);
        ll step=D[a]-D[tp[a]]+1,id=num[tp[a]],cur=step;
        while(cur-RLE[id].back().ss>=0){
            tmp.push_back(RLE[id].back());
            cur-=RLE[id].back().ss;
            RLE[id].pop_back();
        }
        if(cur>0){
            tmp.push_back(mp(RLE[id].back().ff,cur));
            RLE[id].back().ss-=cur;
        }
        while(!tmp.empty()){
            res.push_back(tmp.back());
            tmp.pop_back();
        }
        while(!RLE[id].empty()&&RLE[id].back().ff==C[v]){
            step+=RLE[id].back().ss;
            RLE[id].pop_back();
        }
        RLE[id].push_back(mp(C[v],step));
        vector<pll>().swap(tmp);
        a=par[tp[a]];
    }
    left(a,b,v);
    ll ans=0;
    while(!res.empty()){
        ll x=res.back().ff,c=res.back().ss;
        res.pop_back();
        ans+=query(x+1,n)*c;
        modify(x,c);
    }
    vector<pll>().swap(res);
    cout<<ans<<'\n';
}
void init(){
    fill(BIT+1,BIT+n+1,0);
}
signed main(){
    ios::sync_with_stdio(0);cin.tie(0);
    cin>>n;
    rep(i,1,n) cin>>C[i];
    DIS_BS();
    rep(i,1,n-1){
        cin>>A[i]>>B[i];
        adj[A[i]].push_back(B[i]);
        adj[B[i]].push_back(A[i]);
    }
    HLD();
    rep(o,1,n-1){
        ll u=A[o],v=B[o];
        init();
        solve(u,v);
    }
    return 0;
}
// Grader page residue captured with the submission (no results were loaded):
// | # | Verdict  | Execution time | Memory | Grader output |
// |---|
// | Fetching results... |
// | # | Verdict  | Execution time | Memory | Grader output |
// |---|
// | Fetching results... |
// | # | Verdict  | Execution time | Memory | Grader output |
// |---|
// | Fetching results... |