#include <iostream>
#include <utility>
#include <vector>
using namespace std;
// Array capacity: vertices are numbered from 1, plus a little slack.
constexpr int N = 200005;

// All arrays below have static storage duration, so they start zeroed.
//   v[u]              - value read for vertex u
//   a[u]              - per-vertex prefix XOR, written by dfs
//   cnt1/cnt0[u][bit] - per-vertex, per-bit counters populated by dfs
long long v[N];
long long a[N];
long long cnt1[N][32];
long long cnt0[N][32];

// Per-bit scratch counters (zero at program start).
int currcnt1[32];
int currcnt0[32];

// Running answer total, updated by dfs.
long long ans = 0;

// Undirected adjacency lists, indexed by vertex.
std::vector<long long> g[N];
// Adds to the global `ans` the XOR-sum of every path with two or more vertices
// that lies inside the subtree rooted at `node` (entered from `parent`).
//
// On return, for every vertex u of that subtree:
//   a[u]         = a[parent(u)] ^ v[u]   (prefix XOR from the top of the walk)
//   cnt1[u][bit] = #vertices w in subtree(u) with `bit` of a[w] set
//   cnt0[u][bit] = #vertices w in subtree(u) with `bit` of a[w] clear
//
// For x, y with lowest common ancestor l, the path XOR is a[x] ^ a[y] ^ v[l].
// The pair (l, x) fits the same formula with y = l, because
// a[l] ^ a[x] ^ v[l] == a[parent(l)] ^ a[x]. So each vertex is counted as the
// first member of its own subtree, and children are merged into it one by one.
// Every cross pair is then counted exactly once, at its LCA.
//
// The walk is iterative. A path-shaped tree with ~2e5 vertices would otherwise
// recurse that deep and could overflow the call stack. Per-call state lives in
// local containers, not globals, so no call can clobber another's counters.
void dfs(int node, int parent)
{
    // Bits 0..30 are examined, as in the original solution; this assumes
    // all values are below 2^31 (NOTE(review): confirm against constraints).
    const int kBits = 31;

    // Pass 1: pre-order. Each vertex appears after its parent, and all of its
    // descendants appear after it. Each entry is (vertex, parent).
    std::vector<std::pair<int, int> > order;
    std::vector<std::pair<int, int> > pending;
    pending.push_back(std::make_pair(node, parent));
    while (!pending.empty())
    {
        const std::pair<int, int> cur = pending.back();
        pending.pop_back();
        const int u = cur.first;
        const int p = cur.second;

        a[u] = a[p] ^ v[u];
        for (int bit = 0; bit < kBits; ++bit)
        {
            // Seed the counters with u itself; descendants are merged in pass 2.
            const bool set = ((a[u] >> bit) & 1) != 0;
            cnt1[u][bit] = set ? 1 : 0;
            cnt0[u][bit] = set ? 0 : 1;
        }
        order.push_back(cur);

        for (auto w : g[u])
        {
            if (w != p)
            {
                pending.push_back(std::make_pair(static_cast<int>(w), u));
            }
        }
    }

    // Pass 2: reverse pre-order. A vertex is merged into its parent only after
    // all of its own descendants have been merged into it. order[0] is `node`
    // itself; its parent lies outside the subtree, so it is skipped.
    for (int idx = static_cast<int>(order.size()) - 1; idx >= 1; --idx)
    {
        const int u = order[idx].first;
        const int p = order[idx].second;
        for (int bit = 0; bit < kBits; ++bit)
        {
            const long long k = 1LL << bit;
            // Pairs (x in subtree(u), y among p and its already-merged children).
            const long long matching =
                cnt1[u][bit] * cnt1[p][bit] + cnt0[u][bit] * cnt0[p][bit];
            const long long differing =
                cnt1[u][bit] * cnt0[p][bit] + cnt0[u][bit] * cnt1[p][bit];
            // This bit of a[x] ^ a[y] ^ v[p] is set when a[x], a[y] agree and
            // v[p] has it set, or when they disagree and v[p] has it clear.
            ans += (((v[p] >> bit) & 1) ? matching : differing) * k;

            cnt1[p][bit] += cnt1[u][bit];
            cnt0[p][bit] += cnt0[u][bit];
        }
    }
}
// Reads n vertex values, then n-1 undirected edges (1-indexed vertices), and
// prints ans + s. Here `ans` is the multi-vertex path total accumulated by dfs
// from vertex 1, and `s` is the sum of all values, i.e. the single-vertex paths.
int main()
{
    // Roughly 3n integers are read. Only iostreams are used (no C stdio), so
    // untying and unsyncing the streams is safe and makes input much faster.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    long long s = 0;
    std::cin >> n;
    for (int i = 1; i <= n; ++i)
    {
        std::cin >> v[i];
        s += v[i];
    }
    for (int i = 0; i < n - 1; ++i)
    {
        int x = 0;
        int y = 0;
        std::cin >> x >> y;
        g[x].push_back(y);
        g[y].push_back(x);
    }
    dfs(1, 0);
    // '\n' rather than endl: the stream is flushed at normal program exit.
    std::cout << ans + s << '\n';
}
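/*
 * Judge results table captured together with this source (results were still loading):
 * | # | Verdict | Execution time | Memory | Grader output |
 * |---|---------|----------------|--------|---------------|
 * | Fetching results... | | | | |
 */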