#include<bits/stdc++.h>
using namespace std;
// Capacity of the vertex-indexed arrays (vertices are numbered from 1).
constexpr int N = 100000 + 5;
using LL = long long;
// a[v]: value written on vertex v.
int a[N];
// Adjacency lists of the tree (both directions of every edge are stored).
std::vector<int> graph[N];
// Per-bit subtree counters: written by dfs(), read by dfs2().
LL cnt0[N];
LL cnt1[N];
// Subtree pass for bit `tw` of the vertex values.
// Treats the tree as rooted at `x` (never stepping into neighbour `fa`) and,
// for every vertex v of that subtree, sets
//   cnt1[v] = number of vertices w in v's subtree such that bit `tw` of the
//             XOR of a[] over the path v..w (both ends inclusive) is 1,
//   cnt0[v] = the same count for bit value 0.
// v itself is counted (the one-vertex path v..v), so cnt0[v] + cnt1[v] equals
// the size of v's subtree.
//
// Iterative on purpose: a recursive DFS nests as deep as the tree is tall,
// which for a path-shaped tree with ~1e5 vertices can overflow the call stack.
void dfs(int x, int tw, int fa = -1){
    // Pre-order list of (vertex, parent). Every child appears after its
    // parent, so scanning it backwards finishes all children before a parent.
    vector<pair<int, int>> order;
    vector<pair<int, int>> pending;
    pending.push_back(make_pair(x, fa));
    while(!pending.empty()){
        pair<int, int> cur = pending.back();
        pending.pop_back();
        cnt0[cur.first] = 0;
        cnt1[cur.first] = 0;
        order.push_back(cur);
        for(int child : graph[cur.first])
            if(child != cur.second)
                pending.push_back(make_pair(child, cur.first));
    }
    for(size_t k = order.size(); k-- > 0; ){
        const int v = order[k].first;
        // cnt0[v]/cnt1[v] now hold the children's totals, i.e. paths starting
        // just below v. Prepending a[v] flips the bit when a[v] has it set;
        // then v adds its own one-vertex path.
        if(a[v] & (1 << tw)){
            swap(cnt0[v], cnt1[v]);
            cnt1[v]++;
        }else{
            cnt0[v]++;
        }
        // order[0] is x, whose parent `fa` lies outside the subtree.
        if(k > 0){
            const int p = order[k].second;
            cnt0[p] += cnt0[v];
            cnt1[p] += cnt1[v];
        }
    }
}
LL ans = 0;
void dfs2(int x, int tw, LL up0 = 0, LL up1 = 0, int fa = -1){
ans += cnt1[x] * (1 << tw);
if(a[x] & (1 << tw))
swap(up0, up1);
ans += up1 * (1 << tw);
for(auto i : graph[x]){
if(i == fa)
continue;
if(a[x] & (1 << tw)){
dfs2(i, tw, up0 + cnt0[x] - cnt1[i], up1 + cnt1[x] - cnt0[i], x);
}else{
dfs2(i, tw, up0 + cnt0[x] - cnt0[i], up1 + cnt1[x] - cnt1[i], x);
}
}
}
// Reads n, the vertex values a[1..n] and n-1 undirected edges, then prints the
// sum of XOR(a on the path u..v) over all unordered vertex pairs, including
// the pairs with u == v.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n;
    cin >> n;
    LL sum = 0;   // sum of a[i] = total contributed by the u == v pairs
    int maxA = 0; // largest value; bounds which bits can be non-zero
    for(int i = 1; i <= n; i++){
        cin >> a[i];
        sum += a[i];
        maxA = max(maxA, a[i]);
    }
    for(int i = 0; i < n - 1; i++){
        int u, v;
        cin >> u >> v;
        graph[u].push_back(v);
        graph[v].push_back(u);
    }
    // Bits above the highest set bit of maxA contribute nothing, so stopping
    // there is exact. A fixed bit range (e.g. 0..22) would ignore higher bits
    // while `sum` still contains them, producing a wrong answer.
    for(int bit = 0; (maxA >> bit) != 0; bit++){
        dfs(1, bit);
        dfs2(1, bit);
    }
    // ans counts each u == v pair once (their total is `sum`) and each u != v
    // pair twice, so halve only the u != v part.
    // NOTE(review): with values near 2^31 and n ~ 1e5 the ordered-pair total
    // in `ans` can exceed the LL range; confirm against the input limits.
    cout << (ans - sum) / 2 + sum << '\n';
}
/*
Grader results (pasted from the judge):

| #  | Result (결과) | Run time (실행 시간) | Memory (메모리) | Grader output     |
|----|---------------|----------------------|-----------------|-------------------|
| 1  | Correct       | 1 ms                 | 2644 KB         | Output is correct |
| 2  | Correct       | 1 ms                 | 2644 KB         | Output is correct |
| 3  | Correct       | 1 ms                 | 2644 KB         | Output is correct |
| 4  | Correct       | 2 ms                 | 2644 KB         | Output is correct |
| 5  | Correct       | 3 ms                 | 2644 KB         | Output is correct |
| 6  | Correct       | 172 ms               | 18700 KB        | Output is correct |
| 7  | Correct       | 168 ms               | 18776 KB        | Output is correct |
| 8  | Correct       | 156 ms               | 9760 KB         | Output is correct |
| 9  | Correct       | 182 ms               | 8792 KB         | Output is correct |
| 10 | Correct       | 339 ms               | 7836 KB         | Output is correct |
*/