#include <iostream>
#include <utility>
#include <vector>
using namespace std;
// Capacity for vertex-indexed arrays; vertices are 1..n, and index 0 is the
// sentinel parent passed to dfs(1, 0).
constexpr int N = 200005;

// v[i]: value stored on vertex i (read in main).
long long v[N];
// a[i]: XOR of v[] along the path from the root down to i, inclusive.
long long a[N];

// cnt1[i][b] / cnt0[i][b]: number of vertices in the subtree of i whose a[]
// has bit b set / clear (bits 0..30 are used).
int cnt1[N][32]{};
int cnt0[N][32]{};

// Scratch per-bit tallies used while merging child subtrees of one vertex;
// dfs() resets them to zero after each vertex.
int currcnt1[32]{};
int currcnt0[32]{};

// Running XOR-sum over all paths containing at least two vertices.
long long ans = 0;

// Adjacency lists of the undirected tree.
std::vector<int> g[N];
void dfs(int node, int parent)
{
    a[node] = a[parent] ^ v[node];
    for (auto i : g[node])
    {
        if (i == parent)
        {
            continue;
        }
        dfs(i, node);
    }
    for(auto i:g[node])
    {
        for (int bit = 30;bit >= 0;--bit)
        {
            long long k = (1ll << bit);
            if ((v[node] & k) > 0)
            {
                
                ans += ((cnt1[i][bit] * currcnt1[bit]) + (cnt0[i][bit] * currcnt0[bit])) * k;
                if ((a[node] & k) > 0)
                {
                    ans += cnt1[i][bit]*k;
                }
                else
                {
                    ans += cnt0[i][bit]*k;
                }
            }
            else
            {
                
                ans += ((cnt1[i][bit] * currcnt0[bit]) + (cnt0[i][bit] * currcnt1[bit])) * k;
                if ((a[node] & k) > 0)
                {
                    ans += cnt0[i][bit]*k;
                }
                else
                {
                    ans += cnt1[i][bit]*k;
                }
            }
            currcnt1[bit] += cnt1[i][bit];
            currcnt0[bit] += cnt0[i][bit];
        }
    }
    for (int bit = 30;bit >= 0;--bit)
    {
        cnt1[node][bit] += currcnt1[bit];
        cnt0[node][bit] += currcnt0[bit];
        if (((1 << bit) & a[node]) > 0)
        {
            cnt1[node][bit]++;
        }
        else
        {
            cnt0[node][bit]++;
        }
        currcnt0[bit] = 0;
        currcnt1[bit] = 0;
    }
}
// Reads n vertex values and n-1 undirected edges (vertices 1..n), then prints
// the XOR-sum over all paths of the tree, including single-vertex paths.
int main()
{
    // More than 4e5 integers are read; unsync/untie iostreams so input
    // parsing is not the bottleneck.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n = 0;
    cin >> n;

    // s: XOR-sum of all single-vertex paths (a lone vertex's XOR is its own
    // value); dfs() only accounts for paths with at least two vertices.
    long long s = 0;
    for (int i = 1; i <= n; ++i)
    {
        cin >> v[i];
        s += v[i];
    }

    // n-1 undirected edges, stored in both directions.
    for (int i = 0; i < n - 1; ++i)
    {
        int x = 0;
        int y = 0;
        cin >> x >> y;
        g[x].push_back(y);
        g[y].push_back(x);
    }

    // Root the tree at vertex 1; 0 is a sentinel parent whose a[0] is never
    // written and so stays 0.
    dfs(1, 0);
    cout << ans + s << '\n';
}
// | # | Verdict | Execution time | Memory | Grader output |
// |---|
// | Fetching results... |