#include <bits/stdc++.h>
using namespace std;
#define fi first
#define se second

// Upper bound on the number of vertices; every per-vertex array is sized by it.
constexpr int nax = 200007;

// Adjacency lists of the tree (0-indexed vertices).
vector<int> adj[nax];
// Edges (x, y) in input order, converted to 0-indexed vertices by main().
vector<pair<int,int>> edges;
// Colour of each vertex; main() overwrites it with the colour's compressed rank.
vector<int> colors(nax);
// tab[t]: colours of the vertices whose chain top is t, in the order dfs() emits them.
vector<int> tab[nax];
// comp[t]: run-length encoding of tab[t] as (colour, run length) pairs.
vector<pair<int,int>> comp[nax];
// Subtree size, heavy-chain top, parent and depth of every vertex.
int sz[nax];
int tp[nax];
int parent[nax];
int depth[nax];
// Compute sz[v] (number of vertices in v's subtree) for every vertex v in the
// subtree of x, where p is x's parent (-1 for the root). For every proper
// descendant u of x it also records parent[u] and depth[u] = depth[parent[u]] + 1;
// parent[x] and depth[x] are left untouched. Returns sz[x].
// Iterative on purpose: the recursive form nests one call per tree level and
// can overflow a default-sized stack on path-shaped trees of ~1e5+ vertices.
int precompute_sz(int x,int p){
    // (vertex, parent) pairs in BFS order: every vertex appears after its parent.
    vector<pair<int,int>> order;
    order.push_back(make_pair(x, p));
    for (size_t i = 0; i < order.size(); ++i) {
        const int v = order[i].first;
        const int pv = order[i].second;
        sz[v] = 1;
        for (int u : adj[v]) {
            if (u == pv) continue;
            parent[u] = v;
            depth[u] = depth[v] + 1;
            order.push_back(make_pair(u, v));
        }
    }
    // Children sit later in `order` than their parent, so sweeping backwards
    // completes each subtree size before it is added into its parent.
    // Index 0 is x itself, whose parent p is outside the subtree.
    for (size_t i = order.size() - 1; i > 0; --i)
        sz[order[i].second] += sz[order[i].first];
    return sz[x];
}
// Heavy-light decomposition of the subtree of x (p = x's parent, -1 for the root):
// sets tp[x] = top and, for every descendant u, tp[u] = the top vertex of u's
// heavy chain. A vertex's heavy child is the child with the largest (sz, id)
// pair and continues its parent's chain; every other child starts a new chain.
// Requires sz[] from precompute_sz. Iterative (explicit queue) so that deep,
// path-shaped trees cannot overflow the call stack.
void dfs_hld(int x,int p,int top){
    tp[x] = top;
    // (vertex, parent) pairs; every vertex is appended after its parent, so its
    // tp is already known when its children are processed.
    vector<pair<int,int>> order;
    order.push_back(make_pair(x, p));
    for (size_t i = 0; i < order.size(); ++i) {
        const int v = order[i].first;
        const int pv = order[i].second;
        pair<int,int> big_child = make_pair(-1, -1);
        for (int u : adj[v]) {
            if (u == pv) continue;
            big_child = max(big_child, make_pair(sz[u], u));
        }
        for (int u : adj[v]) {
            if (u == pv) continue;
            tp[u] = (u == big_child.second) ? tp[v] : u;
            order.push_back(make_pair(u, v));
        }
    }
}
// Post-order walk of the subtree of x (p = x's parent, -1 for the root) that
// appends colors[v] to tab[tp[v]] for every vertex v. Children are finished
// before their parent, so along each heavy chain tab lists the deepest vertex
// first and the chain top last.
// Uses an explicit stack that visits vertices in exactly the same order as the
// recursive version, without one call frame per tree level.
void dfs(int x,int p){
    // One frame per vertex on the current root-to-vertex path.
    struct Frame {
        int v;             // vertex
        int par;           // its parent (skipped when scanning neighbours)
        size_t child_idx;  // next index into adj[v] to examine
    };
    vector<Frame> st;
    st.push_back(Frame{x, p, 0});
    while (!st.empty()) {
        const int v = st.back().v;
        if (st.back().child_idx < adj[v].size()) {
            const int u = adj[v][st.back().child_idx++];
            if (u != st.back().par) st.push_back(Frame{u, v, 0});
        } else {
            // All neighbours handled: emit v in post-order.
            tab[tp[v]].push_back(colors[v]);
            st.pop_back();
        }
    }
}
// Runs (colour, length) removed by query(x, c), appended chain by chain from
// x's chain upwards, each chain's runs deepest-first.
vector<pair<int,int>> q;
// Repaint the path from the root to x with colour c.
// For each heavy chain on that path (starting with x's chain), it:
//  - pops from comp[top] the runs covering the chain from its top down to the
//    current vertex, splitting the last popped run if it reaches below that vertex;
//  - pushes a single (c, length) run in their place;
//  - appends the popped runs to q, deepest first.
// It then continues from the parent of the chain top until chain top 0 is handled.
void query(int x,int c){
    for (int v = x; ; v = parent[tp[v]]) {
        const int head = tp[v];
        vector<pair<int,int>>& runs = comp[head];
        const int need = depth[v] - depth[head] + 1;  // vertices from head down to v
        vector<pair<int,int>> removed;                // popped runs, head-side first
        int taken = 0;
        while (taken < need) {
            const pair<int,int> top_run = runs.back();
            runs.pop_back();
            removed.push_back(top_run);
            taken += top_run.second;
        }
        const int excess = taken - need;
        if (excess > 0) {
            // The deepest popped run extends below v: hand that part back.
            runs.push_back(make_pair(removed.back().first, excess));
            removed.back().second -= excess;
        }
        runs.push_back(make_pair(c, need));
        q.insert(q.end(), removed.rbegin(), removed.rend());
        if (head == 0) break;
    }
}
// Point-add / range-sum segment tree over positions [0, n-1] (compressed colour
// ranks in main). The root is node 0 and node pos has children 2*pos+1 and 2*pos+2.
// lazy[pos] == 1 marks a pending "clear this whole subtree"; main() sets
// lazy[0] = 1 to empty the tree between queries without visiting every node.
int segtree[nax*4];
bool lazy[nax*4];

// Apply a pending clear at node pos. It zeroes the node's own sum, so a freshly
// flagged root is cleared too. When the node has children, it also zeroes them
// and marks them pending.
// Pass has_children = false for leaves. A leaf's index can approach 4*n, so
// 2*pos+1 / 2*pos+2 may lie past the end of the 4*nax arrays when n is near nax.
void extend(int pos, bool has_children = true){
    if(!lazy[pos]) return;
    lazy[pos] = 0;
    segtree[pos] = 0;
    if(!has_children) return;
    const int left_child = pos*2+1;
    const int right_child = pos*2+2;
    segtree[left_child] = 0;
    lazy[left_child] = 1;
    segtree[right_child] = 0;
    lazy[right_child] = 1;
}

// Add value at position idx; node pos covers positions [l, r].
void update(int pos,int l,int r,int idx,int value){
    if(l>r) return;
    extend(pos, l<r);
    if(l==r){
        segtree[pos]+=value;
        return;
    }
    const int mid=(l+r)/2;
    if(idx<=mid) update(pos*2+1,l,mid,idx,value);
    else update(pos*2+2,mid+1,r,idx,value);
    segtree[pos]=segtree[pos*2+1]+segtree[pos*2+2];
}

// Sum over positions in [left, right] that lie within node pos's range [l, r];
// returns 0 when the ranges do not intersect or [left, right] is empty.
int query(int pos,int l,int r,int left,int right){
    if(l>r||l>right||r<left||right<left) return 0;
    extend(pos, l<r);
    if(left<=l&&r<=right) return segtree[pos];
    const int mid=(l+r)/2;
    return query(pos*2+1,l,mid,left,right)+query(pos*2+2,mid+1,r,left,right);
}
int main() {
int n;
cin>>n;
map<int,vector<int>> mp;
for (int i = 0; i < n; ++i)
{
cin>>colors[i];
mp[colors[i]].push_back(i);
}
int j=0;
for(auto u:mp){
for(auto i:u.se) colors[i]=j;
j++;
}
for (int i = 0; i < n-1; ++i)
{
int x,y;
cin>>x>>y;
x--;y--;
edges.push_back({x,y});
adj[x].push_back(y);
adj[y].push_back(x);
}
precompute_sz(0,-1);
dfs_hld(0,-1,0);
dfs(0,-1);
for (int i = 0; i < n; ++i)
{
int lst=0;
for (int j = 1; j < tab[i].size(); ++j)
{
if(tab[i][j]!=tab[i][j-1]){
comp[i].push_back({tab[i][j-1],j-lst});
lst=j;
}
}
if(!tab[i].empty()) comp[i].push_back({tab[i].back(),tab[i].size()-lst});
}
for (int i = 0; i < n-1; ++i)
{
q.clear();
lazy[0]=1;
long long ans=0;
query(edges[i].se,colors[edges[i].se]);
for(int i=0;i<q.size();i++){
pair<int,int> u=q[i];
if(i==0&&u.se==1) continue;
else if(i==0) u.se-=1;
ans+=1ll*query(0,0,n-1,0,u.fi-1)*u.se;
update(0,0,n-1,u.fi,u.se);
}
cout <<ans<<endl;
}
}
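/*
 * Residue pasted from the online judge's submission page. It is not part of
 * the program and is kept only as a comment so the file still compiles.
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */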