#include <bits/stdc++.h>
using namespace std;
// Integer shorthands shared by the whole solution.
using ll = long long;
using ii = pair<int, int>;  // NOTE(review): not referenced anywhere in this file.

// Capacity of every per-node array. Nodes are 1-indexed (1..n), so n is
// presumably bounded by 100000 — confirm against the problem statement.
constexpr int maxn = 100000 + 5;

int n;                // number of nodes, read in main()
int a[maxn];          // a[i]: value attached to node i
vector<int> g[maxn];  // g[u]: neighbours of u (each input edge is stored both ways)
// dp[u][p]: number of downward paths that start at u, end inside u's subtree
// (u alone included) and whose XOR of val[] equals p; filled in by dfs().
int dp[maxn][2];
int val[maxn];        // val[i]: the bit of a[i] currently being processed (0 or 1)
ll contrib = 0;       // number of paths with odd XOR of val[] found by dfs()
void dfs(int u = 1, int dad = -1) {
for(int v : g[u]) if(v ^ dad) {
dfs(v, u);
}
dp[u][0] = dp[u][1] = 0;
dp[u][val[u]] = 1;
contrib += val[u];
for(int v : g[u]) if(v ^ dad) {
for(int t = 0; t < 2; t++) {
contrib += dp[v][t ^ 1] * dp[u][t];
dp[u][t] += dp[v][t ^ val[u]];
}
}
}
// Reads a tree with a value on every node and prints the sum, over every path
// (each unordered pair of endpoints counted once, single nodes included), of
// the XOR of the values on that path.
//
// Input: n, then a[1..n], then n - 1 edges "u v" (1-indexed nodes).
//
// The answer is built bit by bit. For each bit position:
//   1. val[] is set to that bit of every a[j].
//   2. dfs() counts the paths whose XOR has that bit set.
//   3. Each such path contributes 2^bit to the answer.
int main(int argc, char const *argv[])
{
    // Bulk integer input: decouple iostreams from C stdio for speed.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
#ifdef LOCAL
    freopen("in", "r", stdin);
#endif
    cin >> n;
    // OR of all values, so only bit positions that can actually be set are
    // processed. Assumes the values are non-negative.
    int bits_present = 0;
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
        bits_present |= a[i];
    }
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        g[u].emplace_back(v);
        g[v].emplace_back(u);
    }

    ll ans = 0;
    // The previous hard-coded limit of 22 bits silently ignored any higher
    // bits. Every bit of an int value is covered here, and the loop stops as
    // soon as no higher bit is set anywhere.
    for (int bit = 0; bit < 31 && (bits_present >> bit) != 0; bit++) {
        for (int j = 1; j <= n; j++) {
            val[j] = (a[j] >> bit) & 1;
        }
        contrib = 0;
        dfs();
        ans += contrib * (1LL << bit);
    }
    cout << ans << '\n';
    return 0;
}
/*
Grader results for the submission above:

| #  | Result  | Time   | Memory   | Grader output     |
|----|---------|--------|----------|-------------------|
| 1  | Correct | 4 ms   | 2680 KB  | Output is correct |
| 2  | Correct | 4 ms   | 2680 KB  | Output is correct |
| 3  | Correct | 4 ms   | 2680 KB  | Output is correct |
| 4  | Correct | 6 ms   | 2808 KB  | Output is correct |
| 5  | Correct | 6 ms   | 2808 KB  | Output is correct |
| 6  | Correct | 240 ms | 17156 KB | Output is correct |
| 7  | Correct | 236 ms | 17016 KB | Output is correct |
| 8  | Correct | 242 ms | 10744 KB | Output is correct |
| 9  | Correct | 251 ms | 9976 KB  | Output is correct |
| 10 | Correct | 337 ms | 9356 KB  | Output is correct |
*/