#define _CRT_SECURE_NO_WARNINGS
#include <bits/stdc++.h>
#include <random>
#include <chrono>
#include <unordered_set>
using namespace std;
// Shorthand for 64-bit signed integers.
using ll = long long;
// Sorts a container and removes adjacent duplicates, leaving only distinct values.
#define UNIQUE_SORT(vec) do { \
sort((vec).begin(), (vec).end()); \
(vec).erase(unique((vec).begin(), (vec).end()), (vec).end()); \
} while(0)
#define yes cout << "YES" << endl
#define no cout << "NO" << endl
// Index of the first element >= x (LB) or > x (UB) in a sorted container.
#define LB(c, x) distance((c).begin(), lower_bound(all(c), (x)))
#define UB(c, x) distance((c).begin(), upper_bound(all(c), (x)))
#define ss second
#define ff first
// Macro arguments are parenthesized so expressions such as `all(*p)` expand to
// `(*p).begin()` rather than `*p.begin()`; whole results are parenthesized so
// `MIN(x)` composes safely with postfix operators.
#define all(X) (X).begin(), (X).end()
#define rall(X) (X).rbegin(), (X).rend()
#define MIN(v) (*min_element(all(v)))
#define MAX(v) (*max_element(all(v)))
#define cinall(X) for(auto &i : (X)) cin >> i
#define printall(X) for(auto &i : (X)) cout << i
// Capacity for vertex indices. Vertices are 1-indexed; index 0 serves as the
// virtual parent of the root (its xxor stays 0).
const int N = 2e5 + 5;
// Number of bit positions examined per value.
// NOTE(review): assumes every v[i] < 2^LOG (about 1.07e9) — confirm against the constraints.
const int LOG = 30;
long long v[N];          // value stored at each vertex
std::vector<int> G[N];   // adjacency lists of the tree
// cnt[u][j][b] = number of vertices w in the subtree of u whose xxor[w] has bit j equal to b.
// Counts never exceed the vertex count, so int suffices; this halves the table
// (about 48 MB instead of 96 MB with long long).
int cnt[N][LOG][2];
// xxor[u] = XOR of v over the root-to-u path, u included.
long long xxor[N];
// Running total of path XORs accumulated by dfs.
// NOTE(review): with n near 2e5 and values near 2^30 the total can exceed LLONG_MAX — confirm the bounds.
long long ans = 0;

// Post-order step for one vertex whose children are already finished:
// fills cnt[node] and adds to `ans` the XOR of every path whose highest vertex
// (the one closest to the root) is `node`.
static void accumulate_node(int node, int parent)
{
    // xxor[node] ^ v[node] == xxor[parent]; XOR-ing it with xxor[w] yields the
    // XOR of the node..w path for any w in node's subtree.
    const long long above = xxor[node] ^ v[node];
    for (int j = 0; j < LOG; j++)
    {
        const long long bit = 1ll << j;
        const int node_bit = (int)((v[node] >> j) & 1);
        long long seen[2] = {0, 0};   // bit-j counts over the children merged so far
        for (int child : G[node])
        {
            if (child == parent)
                continue;
            // Paths a..node..b with a, b in different child subtrees have XOR
            // xxor[a] ^ xxor[b] ^ v[node]; bit j is set iff bit_b == bit_a ^ node_bit ^ 1.
            for (int c = 0; c < 2; c++)
                ans += bit * cnt[child][j][c] * seen[c ^ node_bit ^ 1];
            seen[0] += cnt[child][j][0];
            seen[1] += cnt[child][j][1];
        }
        seen[(xxor[node] >> j) & 1]++;   // the node itself
        cnt[node][j][0] = (int)seen[0];
        cnt[node][j][1] = (int)seen[1];
        // Paths from node down to any w in its subtree (w == node included).
        ans += bit * seen[((above >> j) & 1) ^ 1];
    }
}

// Processes the subtree hanging from `node`, entered from `parent` (xxor[parent]
// must already hold the prefix XOR above `node`; it is 0 for parent 0).
// Sets xxor and cnt for every vertex of the subtree and adds to `ans` the XOR of
// every path with both endpoints in that subtree, single-vertex paths included.
// Iterative, so deep trees (e.g. a 2e5-vertex path) cannot overflow the call stack.
void dfs(int node, int parent)
{
    std::vector<std::pair<int, int> > order;   // (vertex, parent) in pre-order
    std::vector<std::pair<int, int> > pending(1, std::make_pair(node, parent));
    while (!pending.empty())
    {
        const std::pair<int, int> cur = pending.back();
        pending.pop_back();
        const int u = cur.first;
        const int p = cur.second;
        xxor[u] = xxor[p] ^ v[u];   // parent's prefix is final: pre-order
        order.push_back(cur);
        for (int w : G[u])
            if (w != p)
                pending.push_back(std::make_pair(w, u));
    }
    // Reverse pre-order handles every vertex after all of its descendants.
    for (std::size_t k = order.size(); k-- > 0;)
        accumulate_node(order[k].first, order[k].second);
}
void solve()
{
int n;
cin >> n;
for (int i = 1; i <= n; i++)
cin >> v[i];
for (int i = 1; i < n; i++)
{
int u, v;
cin >> u >> v;
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1, 0);
cout << ans << endl;
}
int main() {
    // Unsynchronized C++ streams; cin no longer flushes cout before each read.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int testCases = 1;
    // Single test case; read testCases from cin here to enable multi-test input.
    //cin >> testCases;
    for (; testCases > 0; --testCases) {
        solve();
        cout << endl;
    }
    return 0;
}
// Pasted submission-status table (commented out so the file compiles):
// | # | Verdict | Execution time | Memory | Grader output |
// |---|
// | Fetching results... |