#define _CRT_SECURE_NO_WARNINGS
#include <bits/stdc++.h>
#include <random>
#include <chrono>
#include <unordered_set>
using namespace std;
using ll = long long;

// Sorts a container and removes duplicate elements in place.
#define UNIQUE_SORT(vec) do { \
	sort((vec).begin(), (vec).end()); \
	(vec).erase(unique((vec).begin(), (vec).end()), (vec).end()); \
} while(0)
// Print "YES" / "NO" followed by endl.
#define yes cout << "YES" << endl
#define no cout << "NO" << endl
// Index of the first element >= x (LB) / > x (UB) in a sorted container c.
#define LB(c, x) distance((c).begin(), lower_bound(all(c), (x)))
#define UB(c, x) distance((c).begin(), upper_bound(all(c), (x)))
// Short aliases for pair members.
#define ss second
#define ff first
// Iterator-pair shorthands. The argument is parenthesized so that an expression
// such as all(*p) expands to (*p).begin(), (*p).end() rather than *p.begin(), ...
#define all(X) (X).begin(), (X).end()
#define rall(X) (X).rbegin(), (X).rend()
// Smallest / largest element of a non-empty container. Wrapped in parentheses so
// postfix use such as MIN(v).first applies to the element, not to the iterator.
#define MIN(v) (*min_element(all(v)))
#define MAX(v) (*max_element(all(v)))
// Read into / print every element of a container (printall writes no separators).
#define cinall(X) for(auto &i:X)cin >> i
#define printall(X) for(auto &i:X)cout << i
// Upper bound on vertex ids (vertices are 1-based).
const int N = 2e5 + 5;
// Number of low bits examined. Values are assumed to be < 2^LOG; higher bits
// would be ignored — TODO confirm against the problem constraints.
const int LOG = 30;
// v[u]: value written on vertex u.
long long v[N];
// Adjacency lists of the tree.
std::vector<int> G[N];
// cnt[u][j][b]: number of vertices w in u's subtree whose xxor[w] has bit j == b.
// Each count is at most the number of vertices, so int suffices and halves the
// table (~48 MB instead of ~96 MB as long long).
int cnt[N][LOG][2];
// xxor[u]: XOR of v[] on the path from the traversal root down to u.
// Globals start zeroed, so xxor[0] == 0 acts as "nothing above the root".
long long xxor[N];
// Running total of path XORs accumulated by dfs().
long long ans = 0;

// Adds to `ans` the XOR of every path whose topmost vertex lies in the subtree
// of `node` (whose parent is `parent`), counting single-vertex paths as well.
// Reads xxor[parent] as the XOR above `node`; fills xxor[] and cnt[] for the
// subtree.
//
// For a path a..b whose topmost vertex is w:
//   a, b in different child subtrees of w:  XOR = xxor[a] ^ xxor[b] ^ v[w]
//   b == w, a anywhere in w's subtree:      XOR = xxor[a] ^ (xxor[w] ^ v[w])
// so every bit j can be counted from the per-subtree bit tallies in cnt[].
//
// The traversal is iterative. A recursive version needs one stack frame per
// tree level and can overflow the call stack on a path-shaped tree of ~2e5
// vertices.
void dfs(int node, int parent)
{
	// Pass 1, pre-order (every vertex is visited after its parent): record
	// (vertex, parent) pairs, set xxor[] and tally each vertex's own bits.
	std::vector<std::pair<int, int>> order;
	std::vector<std::pair<int, int>> pending(1, std::make_pair(node, parent));
	while (!pending.empty())
	{
		const std::pair<int, int> cur = pending.back();
		pending.pop_back();
		order.push_back(cur);
		const int u = cur.first, p = cur.second;
		xxor[u] = xxor[p] ^ v[u];
		for (int j = 0; j < LOG; j++)
			cnt[u][j][(xxor[u] >> j) & 1]++;
		for (int c : G[u])
			if (c != p)
				pending.push_back(std::make_pair(c, u));
	}

	// Pass 2, reverse pre-order: each vertex is handled after all of its
	// descendants, so the children's cnt[] tallies are final when merged.
	for (auto it = order.rbegin(); it != order.rend(); ++it)
	{
		const int u = it->first, p = it->second;
		for (int c : G[u])
		{
			if (c == p)
				continue;
			for (int j = 0; j < LOG; j++)
			{
				cnt[u][j][0] += cnt[c][j][0];
				cnt[u][j][1] += cnt[c][j][1];
			}
		}

		const long long above = xxor[u] ^ v[u]; // XOR of the path strictly above u
		for (int j = 0; j < LOG; j++)
		{
			const long long weight = 1ll << j;
			const int vbit = (v[u] >> j) & 1;
			// seen[b]: vertices with bit j of xxor == b in the children handled so far.
			long long seen[2] = {0, 0};
			for (int c : G[u])
			{
				if (c == p)
					continue;
				// One end in c's subtree (bit b), the other in an earlier child's
				// subtree (bit o): bit j of the path XOR is b ^ o ^ vbit, which is
				// set exactly when o == b ^ vbit ^ 1.
				for (int b = 0; b < 2; b++)
					ans += weight * cnt[c][j][b] * seen[b ^ vbit ^ 1];
				seen[0] += cnt[c][j][0];
				seen[1] += cnt[c][j][1];
			}
			// Paths from u down to any a in its subtree (a == u included):
			// bit j is set exactly when bit j of xxor[a] differs from `above`.
			ans += weight * cnt[u][j][((above >> j) & 1) ^ 1];
		}
	}
}
// Reads one test case: n, then the n vertex values (1-based), then n - 1
// undirected edges. Builds the adjacency lists, runs dfs from vertex 1 with
// sentinel parent 0, and prints the accumulated total.
void solve()
{
	int n;
	cin >> n;

	for (int idx = 1; idx <= n; ++idx)
	{
		cin >> v[idx];
	}

	int edgesLeft = n - 1;
	while (edgesLeft-- > 0)
	{
		int a, b;
		cin >> a >> b;
		G[a].push_back(b);
		G[b].push_back(a);
	}

	dfs(1, 0);
	cout << ans << endl;
}
// Entry point: unsynchronised, untied iostreams, then a single test case (the
// multi-test count read is left disabled). Each case is followed by an endl.
int main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	cout.tie(nullptr);

	int tests = 1;
	// cin >> tests;
	while (tests--)
	{
		solve();
		cout << endl;
	}
	return 0;
}
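/*
Judge results table pasted from the submission page (kept as a comment so the
file compiles):
| # | Verdict | Execution time | Memory | Grader output |
|---|
| Fetching results... |
*/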