#include <bits/stdc++.h>
#define ll long long
using namespace std;
// Array capacity: vertices are numbered 1..n; assumes n < N (i.e. n <= 1e5).
const int N=1e5+1;
// Adjacency lists of the (undirected) tree.
vector<int> g[N];
// Running answer: sum of the XOR of vertex values over all paths, including single-vertex ones.
ll answ;
// a[v]: value on vertex v (1-indexed).
ll a[N];
// pref[v]: XOR of a[] along the root-to-v path. pref[0] is never written and stays 0,
// which serves as the "parent of the root" value.
ll pref[N];
// cnt[v][j][b]: number of vertices u in v's subtree whose pref[u] has bit j equal to b.
// Each count is at most n, so int suffices and halves the ~50 MB an ll array would need.
// The middle dimension leaves room for 31 bit positions (values below 2^31).
int cnt[N][31][2];
// Number of vertices.
int n;
void dfs(int v,int p){
pref[v]=pref[p]^a[v];
for(int j=0;j<30;++j){
if(pref[v]&(1<<j)){
cnt[v][j][1]++;
}
else{
cnt[v][j][0]++;
}
}
for(auto i:g[v]){
if(i==p){
continue;
}
dfs(i,v);
for(int j=0;j<30;++j){
cnt[v][j][1]+=cnt[i][j][1];
cnt[v][j][0]+=cnt[i][j][0];
}
}
for(int j=0;j<30;++j){
ll cntz[2];
cntz[0]=0;
cntz[1]=0;
ll u=(1<<j);
for(auto i:g[v]){
if(i==p){
continue;
}
if(((1<<j)&a[v])){
answ+=cntz[0]*cnt[i][j][0]*u;
answ+=cntz[1]*cnt[i][j][1]*u;
}
else{
answ+=cntz[0]*cnt[i][j][1]*u;
answ+=cntz[1]*cnt[i][j][0]*u;
}
cntz[0]+=cnt[i][j][0];
cntz[1]+=cnt[i][j][1];
}
if(((1<<j)&pref[p])){
answ+=cntz[0]*u;
}
else{
answ+=cntz[1]*u;
}
}
}
// Reads n, the vertex values a[1..n], then n-1 undirected edges, and prints the
// sum over all paths (single vertices included, each unordered pair counted once)
// of the XOR of the values on the path.
int main(){
    // The input is roughly 3n integers; desync/untie iostreams so reading is fast.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin>>n;
    for(int i=1;i<=n;++i){
        std::cin>>a[i];
        answ+=a[i];  // the single-vertex path {i}
    }
    for(int i=1;i<n;++i){
        int u,v;
        std::cin>>u>>v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    // Root the tree at vertex 1 (parent 0, pref[0] == 0) and add every path
    // with at least two vertices.
    dfs(1,0);
    std::cout<<answ<<'\n';
    return 0;
}
/*
| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |
*/