#include <bits/stdc++.h>
using namespace std;
#define int long long
/*
IDEA:
1. For each bit, replace every node's value with 1 or 0 depending on whether that bit is set.
2. Now we need the number of paths in the tree that contain an odd number of 1s,
// and then multiply that count by 2^i (if we are checking the i-th bit).
3. To count these paths, fix a node and treat it as the LCA of the path.
4. Only three kinds of paths have this node as their LCA: the single node itself, paths ending at the LCA
// (going down into one child's subtree), and paths passing through two different children of the LCA.
// (No need to consider the parent node, since the LCA would then no longer be this node.)
5. Single node: add 1 if the LCA's value is 1.
6. Ending at the LCA: if the LCA's value is 1, add sum0 (over all children, the number of downward paths
// starting at that child with an even number of 1s); otherwise add sum1 (same, but with an odd number of 1s).
7. Through two children: generalize point 6 to pairs of children, pairing parities so the total number of 1s is odd.
*/
// Capacity of the per-node arrays (nodes are 1-indexed, plus a little slack).
const int N = 100000 + 5;
// Number of nodes in the tree.
int n;
// Node values exactly as read from the input.
int actual_values[N];
// For the bit currently being processed: 1 if that bit is set in the node's value, else 0.
int cur_values[N];
// Adjacency lists of the undirected tree.
vector<int> g[N];
// ending[v][p]: number of downward paths that start at v and stay inside v's subtree
// (the single-node path {v} included) whose count of 1s in cur_values has parity p
// (0 = even, 1 = odd). Filled in by dfs().
int ending[N][2];
/*
Counts the paths lying entirely inside the subtree of `node` (rooted away from `par`)
whose number of 1s in cur_values is odd; single-node paths are included.
As a side effect fills ending[node][0/1] (see its declaration) for use by the caller.
Recursion depth equals the height of the tree below `node`.
*/
int dfs(int node, int par)
{
    ending[node][0] = 0;
    ending[node][1] = 0;

    const int bit_here = cur_values[node];
    // 1 when this node contributes a 1: appending it flips the parity of a child's path.
    const int flip = (bit_here == 1);

    int odd_paths = bit_here;     // the single-node path {node}
    int child_total[2] = {0, 0};  // sum of ending[child][parity] over all children

    for (int child : g[node])
    {
        if (child == par)
            continue;
        odd_paths += dfs(child, node);
        for (int q = 0; q < 2; q++)
        {
            child_total[q] += ending[child][q];
            // Extending a child's downward path by `node` maps parity q to q ^ flip.
            ending[node][q ^ flip] += ending[child][q];
        }
    }

    // Downward paths of length >= 2 starting at node that have odd parity.
    odd_paths += ending[node][1];
    // Only now add the single-node path {node} itself to ending[node].
    ending[node][bit_here]++;

    // Paths through two different children a and b have parity pa ^ bit_here ^ pb.
    // That is odd exactly when pb == pa ^ 1 ^ flip. Every unordered pair of
    // children is visited twice below, so the total is halved at the end.
    int ordered_pairs = 0;
    for (int child : g[node])
    {
        if (child == par)
            continue;
        for (int q = 0; q < 2; q++)
        {
            const int want = q ^ 1 ^ flip;
            ordered_pairs += ending[child][q] * (child_total[want] - ending[child][want]);
        }
    }

    return odd_paths + ordered_pairs / 2;
}
/*
Reads a tree of n nodes (n values, then n - 1 undirected 1-indexed edges) and
prints the sum, over every unordered pair {u, v} including u == v, of the XOR
of the node values on the path between u and v.
The sum is assembled bit by bit: for each bit, cur_values[] is set to that bit
of every value and dfs(1, -1) counts the paths with an odd number of such 1s;
each of those paths contributes 2^bit.
NOTE(review): assumes 1 <= n < N and non-negative values; the answer is kept
in a (long long) int, so very large inputs could overflow.
*/
void Solve()
{
    cin >> n;
    int max_value = 0;
    for (int i = 1; i <= n; i++)
    {
        cin >> actual_values[i];
        max_value = max(max_value, actual_values[i]);
    }
    for (int i = 1; i < n; i++)
    {
        int u, v;
        cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    int ans = 0;
    // Only bits up to the highest set bit of the largest value can contribute.
    // Deriving the range from the data (rather than a fixed bit count) keeps
    // high bits from being silently dropped. The 64-bit mask avoids an int shift overflow.
    for (int cur_bit = 0; (max_value >> cur_bit) > 0; cur_bit++)
    {
        const int mask = 1LL << cur_bit;
        for (int i = 1; i <= n; i++)
        {
            cur_values[i] = (actual_values[i] & mask) ? 1 : 0;
        }
        const int res = dfs(1, -1);
        ans += res * mask;
    }
    cout << ans << "\n";
}
// Program entry: switch iostreams to unsynchronized, untied mode for fast bulk I/O,
// then solve the single test case. Returns 0 implicitly.
int32_t main()
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    Solve();
}
/*
Judge results (pasted from the online judge; not part of the program):
# |
Result |
Execution time |
Memory |
Grader output |
1 |
Correct |
2 ms |
2644 KB |
Output is correct |
2 |
Correct |
2 ms |
2644 KB |
Output is correct |
3 |
Correct |
2 ms |
2644 KB |
Output is correct |
4 |
Correct |
3 ms |
2772 KB |
Output is correct |
5 |
Correct |
3 ms |
2644 KB |
Output is correct |
6 |
Correct |
120 ms |
21664 KB |
Output is correct |
7 |
Correct |
123 ms |
21712 KB |
Output is correct |
8 |
Correct |
109 ms |
13156 KB |
Output is correct |
9 |
Correct |
130 ms |
12272 KB |
Output is correct |
10 |
Correct |
251 ms |
11324 KB |
Output is correct |
*/