#include<bits/stdc++.h>
// rep(a,b,c): inclusive loop a = b, b+1, ..., c with a long long counter.
#define rep(a,b,c) for(ll a=b;a<=c;++a)
#define ll long long
#define ff first
#define ss second
#define mp make_pair
using namespace std;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
// N: capacity of every per-node array (assumes node ids stay below N — TODO
// confirm against the problem limits). inf is not referenced in this file.
const ll N=1e5+5,inf=1e18;
// Per-node arrays, indexed by node id 1..n:
//   C     node value; replaced by its 1-based rank among distinct values in DIS_BS()
//   A, B  endpoints of the i-th input edge; main() processes them in input order
//   sz    subtree size (dfs_sz)
//   par   parent in the tree rooted at node 1 (par[1] is never written, stays 0)
//   hson  heavy child (largest subtree), -1 when the node has no child
//   D     depth (D[1] stays 0)
//   tp    top node of the heavy chain containing the node
//   pos   DFS visit order written by Link(); not read anywhere in this file
//   num   1-based chain index, stored at the chain's top node (0 elsewhere)
//   BIT   Fenwick tree over compressed values (modify()/query())
// timer / cnt: running counters used to assign pos[] and num[].
ll n,C[N],A[N],B[N],sz[N],par[N],hson[N],D[N],tp[N],pos[N],timer,cnt,num[N],BIT[N];
// Undirected adjacency lists of the tree.
vector<ll> adj[N];
// RLE[k]: runs (value, length) of chain k, deepest node first, so back() is
// the run that contains the chain's top node.
vector<pll> RLE[N];
// Scratch buffers: tmp collects runs popped from one chain; res accumulates
// the runs of a whole root path, deepest first.
vector<pll> tmp,res;
// Value of the lowest set bit of x (e.g. 12 -> 4): the Fenwick-tree step.
ll lowbit(ll x){
    const ll withoutLowest=x&(x-1); // x with its lowest set bit cleared
    return x^withoutLowest;
}
// Fenwick point update: add v at index pos (valid indices are 1..n).
void modify(ll pos,ll v){
    ll i=pos;
    while(i<=n){
        BIT[i]+=v;
        i+=lowbit(i);
    }
}
// Fenwick range sum over indices l..r, computed as prefix(r) - prefix(l-1)
// (so l == r+1 yields 0).
ll query(ll l,ll r){
    auto prefix=[](ll k){
        ll total=0;
        for(;k>=1;k-=lowbit(k)) total+=BIT[k];
        return total;
    };
    return prefix(r)-prefix(l-1);
}
void DIS_BS(){
vector<ll> T;
rep(i,1,n) T.push_back(C[i]);
sort(T.begin(),T.end());
T.resize(unique(T.begin(),T.end())-T.begin());
rep(i,1,n) C[i]=lower_bound(T.begin(),T.end(),C[i])-T.begin()+1;
}
// First HLD pass: DFS from u, whose parent is lst (-1 at the root). Sets the
// parent and depth of every child, the subtree size sz[u], and the heavy child
// hson[u]: the first child in adj order with the largest subtree, or -1 when
// u has no child.
void dfs_sz(ll u,ll lst){
    sz[u]=1;
    hson[u]=-1;
    ll heaviest=0; // subtree size of the current heavy child
    for(size_t k=0;k<adj[u].size();++k){
        const ll child=adj[u][k];
        if(child==lst) continue;
        par[child]=u;
        D[child]=D[u]+1;
        dfs_sz(child,u);
        sz[u]+=sz[child];
        if(sz[child]>heaviest){ // strict: earlier child wins ties
            heaviest=sz[child];
            hson[u]=child;
        }
    }
}
// Second HLD pass: u belongs to the chain whose top is `top`. Records tp[u]
// and the visit order pos[u], numbers the chain the first time its top is seen,
// descends into the heavy child first (same chain) and then each light child
// (new chain). Finally appends u's value as a run of length 1. Because that
// happens after all of u's descendants, each chain's RLE ends up ordered
// deepest node first.
void Link(ll u,ll top){
    tp[u]=top;
    timer+=1;
    pos[u]=timer;
    if(num[top]==0){
        cnt+=1;
        num[top]=cnt;
    }
    const ll heavy=hson[u];
    if(heavy!=-1) Link(heavy,top);
    for(const ll w:adj[u]){
        const bool skip=(w==par[u])||(w==heavy);
        if(!skip) Link(w,w);
    }
    RLE[num[top]].push_back(make_pair(C[u],1LL));
}
void HLD(){
dfs_sz(1,-1);
Link(1,1);
}
// Cut the D[a]-D[b]+1 nodes at the top of a's heavy chain off that chain's RLE
// (a/b are swapped first so that a is the deeper one). The cut runs are
// appended to res deepest-first, and the nodes are replaced by a single run of
// value C[v]. That run is merged with the adjacent run just below the cut when
// it has the same value.
// NOTE(review): pops always start at the chain's top, so this is only correct
// when the shallower of a/b is the top of a's chain (true in solve(), b == 1).
void left(ll &a,ll &b,ll v){
    if(D[a]<D[b]) swap(a,b);
    const ll step=D[a]-D[b]+1;
    vector<pll> &chain=RLE[num[tp[a]]];
    ll cur=step;
    // Take whole runs while they fit. cur>0 is tested first: after the last
    // run is taken the vector may be empty, and back() on it is undefined.
    while(cur>0&&chain.back().ss<=cur){
        tmp.push_back(chain.back());
        cur-=chain.back().ss;
        chain.pop_back();
    }
    if(cur>0){ // only part of the next run belongs to the cut
        tmp.push_back(mp(chain.back().ff,cur));
        chain.back().ss-=cur;
    }
    // tmp is top-first; res wants deepest-first.
    res.insert(res.end(),tmp.rbegin(),tmp.rend());
    tmp.clear();
    ll merged=step;
    while(!chain.empty()&&chain.back().ff==C[v]){
        merged+=chain.back().ss;
        chain.pop_back();
    }
    chain.push_back(mp(C[v],merged));
}
void solve(ll u,ll v){
ll a=u,b=1;
while(tp[a]!=tp[b]){
if(D[tp[a]]<D[tp[b]]) swap(a,b);
ll step=D[a]-D[tp[a]]+1,id=num[tp[a]],cur=step;
while(cur-RLE[id].back().ss>=0){
tmp.push_back(RLE[id].back());
cur-=RLE[id].back().ss;
RLE[id].pop_back();
}
if(cur>0){
tmp.push_back(mp(RLE[id].back().ff,cur));
RLE[id].back().ss-=cur;
}
while(!tmp.empty()){
res.push_back(tmp.back());
tmp.pop_back();
}
while(!RLE[id].empty()&&RLE[id].back().ff==C[v]){
step+=RLE[id].back().ss;
RLE[id].pop_back();
}
RLE[id].push_back(mp(C[v],step));
vector<pll>().swap(tmp);
a=par[tp[a]];
}
left(a,b,v);
ll ans=0;
while(!res.empty()){
ll x=res.back().ff,c=res.back().ss;
res.pop_back();
ans+=query(x+1,n)*c;
modify(x,c);
}
vector<pll>().swap(res);
cout<<ans<<'\n';
}
void init(){
fill(BIT+1,BIT+n+1,0);
}
signed main(){
ios::sync_with_stdio(0);cin.tie(0);
cin>>n;
rep(i,1,n) cin>>C[i];
DIS_BS();
rep(i,1,n-1){
cin>>A[i]>>B[i];
adj[A[i]].push_back(B[i]);
adj[B[i]].push_back(A[i]);
}
HLD();
rep(o,1,n-1){
ll u=A[o],v=B[o];
init();
solve(u,v);
}
return 0;
}
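/*
 * Residue that appears to be pasted from an online judge's submission page
 * (the results had not loaded). It is not part of the program and is kept
 * inside this comment so the file compiles.
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/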