#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Array capacity: vertices are numbered 1..n with n up to 1e5, plus a little slack.
constexpr int N = 100000 + 3;

// adj_1[v]: undirected adjacency lists exactly as read from the input edges.
vector<int> adj_1[N];
// adj[v]: children of v once direct() has rooted the tree (main roots it at vertex 1).
vector<int> adj[N];

// a[v]: the single bit of arr[v] currently being processed by the per-bit loop.
bool a[N];
// arr[v]: the value stored at vertex v (1-based).
int arr[N];
// dp[b][v]: memo for dfs(v, b); -1 means "not computed yet".
ll dp[2][N];
// Orients the undirected tree stored in adj_1 away from `node`.
// After the call, adj[v] lists the children of v (in adj_1[v] order, with the
// parent skipped) for every vertex reachable from `node` without stepping back
// through `prev`. If prev != -1, `node` itself is also appended to adj[prev].
//
// Uses an explicit stack instead of recursion: a path-shaped tree with ~1e5
// vertices would otherwise need ~1e5 nested calls and can overflow the stack.
void direct(int node, int prev){
    if(prev != -1){
        adj[prev].push_back(node);
    }
    std::vector<std::pair<int, int>> stk;  // (vertex, parent of that vertex)
    stk.push_back(std::make_pair(node, prev));
    while(!stk.empty()){
        const int v = stk.back().first;
        const int p = stk.back().second;
        stk.pop_back();
        for(int child : adj_1[v]){
            if(child != p){
                // Appended in adj_1[v] order, matching the recursive version.
                adj[v].push_back(child);
                stk.push_back(std::make_pair(child, v));
            }
        }
    }
}
// Returns the number of downward paths that start at `node` (node, then a
// child, grandchild, ... following adj) whose XOR of the current bits a[]
// equals `b`. The single-vertex path {node} counts when a[node] == b.
// Results are memoised in dp[b][node]; a value of -1 means "not computed", so
// dp must be reset to -1 whenever a[] changes.
//
// Evaluated with an explicit post-order stack instead of recursion so a deep
// (path-shaped) tree cannot overflow the call stack. Each (vertex, parity)
// state is expanded once: first its uncomputed children are pushed, then on the
// second visit its value is summed from the children's memo entries.
ll dfs(int node, bool b){
    if(dp[b][node] != -1)
        return dp[b][node];

    struct Frame {
        int v;          // vertex whose dp entry is being computed
        bool want;      // required XOR parity of paths starting at v
        bool expanded;  // children already pushed?
    };
    std::vector<Frame> stk;
    stk.push_back({node, b, false});

    while(!stk.empty()){
        // Copy the fields out: push_back below may reallocate and invalidate
        // any reference into stk.
        const int v = stk.back().v;
        const bool want = stk.back().want;
        const bool expanded = stk.back().expanded;
        // A path v -> child -> ... has parity want iff the part starting at
        // child has parity want ^ a[v].
        const bool childWant = want ^ a[v];

        if(!expanded){
            stk.back().expanded = true;
            for(int child : adj[v])
                if(dp[childWant][child] == -1)
                    stk.push_back({child, childWant, false});
        }else{
            ll ret = (a[v] == want);  // the single-vertex path {v}
            for(int child : adj[v])
                ret += dp[childWant][child];
            dp[want][v] = ret;
            stk.pop_back();
        }
    }
    return dp[b][node];
}
int main() {
int n;
cin>>n;
for(int i = 1; i <= n; i++)
cin>>arr[i];
for(int i = 1; i < n; i++){
int a,b;
cin>>a>>b;
adj_1[a].push_back(b);
adj_1[b].push_back(a);
}
direct(1,-1);
ll ans = 0;
for(int i = 0; (1 << i) < 3e6; i++){
//cout<<"FOR THE "<<i<<"-TH BIT:\n";
//cout<<"a: ";
memset(dp, -1, sizeof(dp));
for(int j = 1; j <= n; j++){
a[j]=!!(arr[j]&(1<<i));
//cout<<a[j]<<' ';
}
// cout<<'\n';
for(int j = 1; j <= n; j++){
ll tot0=0,tot1=0;
for(int child : adj[j]){
tot0 += dfs(child, 0);
tot1 += dfs(child, 1);
}
ll res = 0;
for(int child : adj[j]){
ll cnt0 = dfs(child, 0);
ll cnt1 = dfs(child, 1);
if(a[j]){
res += cnt0 * (tot0 - cnt0) + cnt1 * (tot1 - cnt1);
}else{
res += cnt0 * (tot1 - cnt1) + cnt1 * (tot0 - cnt0);
}
}
//cout<<j<<": "<<tot0<<' '<<tot1<<'\n';
ans += (res / 2 + (a[j] ? tot0 : tot1) + (a[j]==1)) * (1ll << i);
}
//cout<<ans<<'\n';
}
cout<<ans;
return 0;
}
/*
Pasted grader results for this solution. This is kept as a comment so the file
still compiles. Columns: # = test number, 결과 = result,
실행 시간 = execution time, 메모리 = memory.
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
4 ms |
6484 KB |
Output is correct |
2 |
Correct |
4 ms |
6484 KB |
Output is correct |
3 |
Correct |
4 ms |
6484 KB |
Output is correct |
4 |
Correct |
7 ms |
6612 KB |
Output is correct |
5 |
Correct |
5 ms |
6612 KB |
Output is correct |
6 |
Correct |
207 ms |
23016 KB |
Output is correct |
7 |
Correct |
187 ms |
22944 KB |
Output is correct |
8 |
Correct |
194 ms |
15516 KB |
Output is correct |
9 |
Correct |
216 ms |
14848 KB |
Output is correct |
10 |
Correct |
236 ms |
13592 KB |
Output is correct |
*/