// Must come before any standard/CRT header: MSVC only honours this macro when
// it is already defined at the first CRT include; defined afterwards it is a no-op.
#define _CRT_SECURE_NO_WARNINGS
#include <bits/stdc++.h>
using namespace std;
// Contest shorthands. Macro arguments (and expression bodies) are parenthesised
// so that e.g. all(*p) or !no_el(s, k) expand to what they appear to mean.
#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
#define uniq(x) (x).resize(unique(all(x)) - (x).begin());
#define sort_uniq(x) sort(all(x)), uniq(x);
#define no_el(x, y) ((x).find(y) == (x).end())
#define ll long long
#define ld long double
#define pii pair<int, int>
#define pll pair<ll, ll>
#define V vector
#define V2dll V<V<ll>>
#define V2dint V<V<int>>
#define V2dchar V<V<char>>
#define V2dbool V<V<bool>>
#define V3dll V<V<V<ll>>>
#define V3dint V<V<V<int>>>
#define V3dchar V<V<V<char>>>
#define lb lower_bound
#define ub upper_bound
#define pb push_back
#define eb emplace_back
#define FASTIO \
ios_base::sync_with_stdio(false); \
cin.tie(nullptr); \
cout.tie(nullptr);
#define INF INT32_MAX
#define blt __builtin_popcount
#define clr(x) (x).clear()
#define ff first
#define ss second
#define popf pop_front
#define popb pop_back
#define sz(x) int((x).size())
#define rep(a, b, c, d) for (int a = (b); a <= (c); a += (d))
#define repl(a, b, c, d) for (int a = (b); a >= (c); a -= (d))
// Seeded random generator (not used by the code in this file).
mt19937_64 rng(chrono::steady_clock().now().time_since_epoch().count());
const int N = 1e5 + 5;       // array capacity: vertex indices must stay below N
const int LG = 22;           // highest bit position examined (bits 0..LG inclusive)
ll val[N];                   // val[v]: value read for vertex v (1-indexed)
// a[v]: XOR of val along the root..v path (filled by dfs_a).
// cnt[v][bit][b]: vertices in v's subtree whose a[] has bit `bit` equal to b (dfs_cnt).
ll a[N], cnt[N][LG + 1][2];
V<int> adj[N];               // adjacency lists of the undirected tree
ll ans = 0;                  // running total accumulated by dfs_ans
// Propagates root-to-vertex XOR sums below `node`: every vertex v reachable from
// `node` without stepping back into `p` gets a[v] = a[parent(v)] ^ val[v].
// a[node] itself is not written here; the caller seeds it (main sets a[1]).
// Uses an explicit stack of (vertex, parent) pairs instead of recursion; each
// a[v] is assigned only after a[parent(v)] is final, so the result is the same.
void dfs_a(int node, int p) {
    std::vector<std::pair<int, int>> pending{{node, p}};
    while (!pending.empty()) {
        const int cur = pending.back().first;
        const int from = pending.back().second;
        pending.pop_back();
        for (int nxt : adj[cur]) {
            if (nxt == from) continue;
            a[nxt] = a[cur] ^ val[nxt];
            pending.emplace_back(nxt, cur);
        }
    }
}
void dfs_cnt(int node, int p) {
rep(i, 0, LG, 1) {
if (a[node] & (1 << i)) {
cnt[node][i][0] = 0, cnt[node][i][1] = 1;
} else
cnt[node][i][0] = 1, cnt[node][i][1] = 0;
}
for (auto i : adj[node]) if (i != p) {
dfs_cnt(i, node);
rep(it, 0, LG, 1) {
cnt[node][it][0] += cnt[i][it][0];
cnt[node][it][1] += cnt[i][it][1];
}
}
}
void dfs_ans(int node, int p) {
for (auto i : adj[node]) if (i != p)
dfs_ans(i, node);
rep(bt, 0, LG, 1) {
ll cn[2];
cn[0] = cnt[node][bt][0], cn[1] = cnt[node][bt][1];
int df = (a[node] & (1 << bt)) >> bt;
if (df) --cn[1];
else --cn[0];
df = (val[node] & (1 << bt)) >> bt;
for (auto i : adj[node]) if (i != p) {
ans += cnt[i][bt][0] * (cn[df ^ 1] - cnt[i][bt][df ^ 1]) * (1 << bt);
ans += cnt[i][bt][1] * (cn[df] - cnt[i][bt][df]) * (1 << bt);
}
ans += df * (1 << bt);
ans += cn[df ^ 1] * (1 << bt);
}
}
int main()
{
FASTIO
int n; cin >> n;
rep(i, 1, n, 1) cin >> val[i];
rep(i, 1, n - 1, 1) {
int a, b;
cin >> a >> b;
adj[a].pb(b);
adj[b].pb(a);
}
a[1] = val[1];
dfs_a(1, 0);
//rep(i, 1, n, 1) cout << a[i] << " ";
//cout << endl;
dfs_cnt(1, 0);
//rep(node, 1, n, 1) {
// rep(i, 0, LG, 1) cout << cnt[node][i][1] << " ";
// cout << endl;
//}
dfs_ans(1, 0);
cout << ans << "\n";
}
// | # | Verdict | Execution time | Memory | Grader output |
// |---|---------|----------------|--------|---------------|
// | Fetching results... |