#include <bits/stdc++.h>
#define ll long long
using namespace std;
// Problem-wide storage; as globals everything starts zero-initialised.
// Vertices are indexed 1..n with n < N; index 0 serves as the virtual parent of the
// root (the search is started as dfs(1,0)), so pref[0] must stay 0.
const int mod=1e9+7;        // NOTE(review): not referenced anywhere in the visible file
const int N=1e5+1;
std::vector<int> g[N];      // adjacency lists of the undirected tree
long long dp[N];            // NOTE(review): not referenced anywhere in the visible file
long long answ;             // running sum of path XORs (the printed answer)
long long a[N];             // vertex values
long long pref[N];          // pref[v]: XOR of a[] along the path root..v (inclusive)
// cnt[v][j][b]: number of vertices u in v's subtree whose pref[u] has bit j equal to b.
// Entries are vertex counts (at most n < N), so int is sufficient; using long long here
// made this table ~50 MB instead of ~25 MB.
int cnt[N][31][2];
int n;                      // number of vertices
void dfs(int v,int p){
    pref[v]=pref[p]^a[v];
    for(int j=0;j<30;++j){
        if(pref[v]&(1<<j)){
            cnt[v][j][1]++;
        }
        else{
            cnt[v][j][0]++;
        }
    }
    for(auto i:g[v]){
        if(i==p){
            continue;
        }
        dfs(i,v);
        for(int j=0;j<30;++j){
            cnt[v][j][1]+=cnt[i][j][1];
            cnt[v][j][0]+=cnt[i][j][0];
        }
    }
    for(int j=0;j<30;++j){
        ll cntz[2];
        cntz[0]=0;
        cntz[1]=0;
        ll u=(1<<j);
        for(auto i:g[v]){
            if(i==p){
                continue;
            }
            if(((1<<j)&a[v])){
                answ+=cntz[0]*cnt[i][j][0]*u;
                answ+=cntz[1]*cnt[i][j][1]*u;
            }
            else{
                answ+=cntz[0]*cnt[i][j][1]*u;
                answ+=cntz[1]*cnt[i][j][0]*u;
            }
            cntz[0]+=cnt[i][j][0];
            cntz[1]+=cnt[i][j][1];
        }
        if(((1<<j)&pref[p])){
            answ+=cntz[0]*u;
        }
        else{
            answ+=cntz[1]*u;
        }
    }
}
int main(){
    cin>>n;
    for(int i=1;i<=n;++i){
        cin>>a[i];
        answ+=a[i];
    }
    for(int i=1;i<n;++i){
        int u,v;
        cin>>u>>v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs(1,0);
    cout<<answ<<'\n';
    return 0;
}
/*
| # | Verdict | Execution time | Memory | Grader output | 
|---|
| Fetching results... |
*/