#include "bits/stdc++.h"
using namespace std;
// ---------------------------------------------------------------------------
// Contest prelude: shorthand macros, typedefs and a time-seeded RNG.
// Many of these are not referenced by the visible code; they form a reusable
// template.
// NOTE(review): `f`, `s`, `db`, `ins`, `ph` and `pb` are object-like macros, so any
// later identifier with one of those names is silently rewritten by the
// preprocessor (e.g. a local variable `s` would become `second`).
// ---------------------------------------------------------------------------
// Unties C++ streams from C stdio and cin from cout; expands to two statements.
#define FAST ios_base::sync_with_stdio(false); cin.tie(0);
// Large sentinel values for long long / int comparisons.
#define LLINF (long long) 1e18//1234567890987654321
#define INF 1234567890
// Container shorthands.
#define pb push_back
#define ins insert
#define f first
#define s second
// Debug switch read by the `spacing` macro below.
#define db 0
#define EPS (1e-7) //0.0000001 the value
#define PI (acos(-1))
// Size limits; the global arrays in this file are declared with MAXN/3 entries.
#define MAXN 300006
#define MAXK 26
#define MAXX 15000006
// Type shorthands.
#define ll long long int
#define ld long double
// Counting loops with an ll index: half-open [l1, l2) and closed [l1, l2].
#define rep0(kk, l1, l2)for(ll kk = l1; kk < l2; kk++)
#define rep1(kk, l1, l2)for(ll kk = l1; kk <= l2; kk++)
// Iterates a set<ll> with an explicit iterator named `itr`.
#define foritr(itr, A) for(set<ll>::iterator itr = A.begin(); itr != A.end(); itr++)
// Mersenne Twister seeded from the steady clock, so the sequence typically differs per run.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count()); //can be used by calling rng() or shuffle(A, A+n, rng)
// Half-open [ss, ee) loop with an ll index.
#define FOR(ii, ss, ee) for(ll ii = ss; ii < ee; ii++)
// Debug printers to stderr: a value, and element A[x] labelled with the text of x.
#define cr(x) cerr << #x << " = " << x << "\n";
#define crA(x, A) cerr << #x << " = " << A[x] << "\n";
// Prints a single space only when db is non-zero.
#define spacing if(db)cout << " ";
// memset over the whole object/array x (sizeof of x itself, so x must not be a pointer).
#define mmst(x, v) memset((x), v, sizeof ((x)));
// First / last element of a container (dereferences begin()/prev(end()); must be non-empty).
#define bg(ms) (*ms.begin())
#define ed(ms) (*prev(ms.end(), 1))
// Adds undirected edge a-b carrying c to adjacency lists v whose elements are pi.
#define addedge(a, b, c, v) v[(a)].pb(pi((b), (c))); v[(b)].pb(pi((a), (c)))
#define ph push
#define btinpct(x) __builtin_popcountll(x)
// 2^x as a long long.
#define p2(x) (1LL<<(x))
#define all(x) (x).begin(), (x).end()
// Binary searches over a whole sorted container; both yield iterators.
#define lbd(x, y) lower_bound(all(x), y)
#define ubd(x, y) upper_bound(all(x), y)
// pi: pair of ll; spi: ll paired with a pi; dpi: pair of pi.
typedef pair <ll, ll> pi;
typedef pair <ll, pi> spi;
typedef pair <pi, pi> dpi;
// Pseudo-random integer in [x, y], drawn from rng.
// NOTE(review): the modulo reduction is slightly biased, and y == x - 1 makes the
// modulus zero (division by zero); callers are expected to pass x <= y.
inline ll rand(ll x, ll y) { ++y; return (rng() % (y-x)) + x; } // inclusive on both ends
// Global state shared by dfs() and main(); static storage, so everything starts zeroed.
ll n;                      // vertex count, read in main()
ll A[MAXN/3];              // vertex values, 1-indexed
ll tt;                     // NOTE(review): not referenced in the visible code
ll anss;                   // answer accumulated by dfs()
// sum[x][j][b]: once dfs(x) has returned, the number of downward paths that start
// at x (x alone included) whose XOR has bit j equal to b.
int sum[MAXN/3][31][2];
vector<int> v[MAXN/3];     // undirected adjacency lists, 1-indexed
// Depth-first pass over the subtree of x, entered from parent p.
//
// On return:
//   * `anss` has been increased by the XOR-sum of every path whose highest vertex
//     (the one closest to the root) is x, including the single-vertex path {x};
//   * sum[x][j][b] holds the number of downward paths that start at x (x alone
//     included) whose XOR has bit j equal to b; x's parent reads this.
//
// Children are merged into sum[x] one at a time. Just before child i is merged,
// sum[x] describes paths from x into the children already processed (or {x}
// itself), so pairing those with child i's paths counts every path through x
// exactly once. No double counting and no final halving are needed.
//
// Overflow note: every count * 2^j product is formed in long long. The earlier
// 32-bit `count * (1<<j)` overflowed (undefined behaviour) as soon as the result
// exceeded INT_MAX, e.g. bit 30 with two or more such paths.
//
// NOTE(review): only bits 0..30 are examined (the width of sum[][][]), so A[]
// values are assumed to fit in 31 bits. Recursion depth equals the tree height;
// a path-shaped tree with close to MAXN/3 vertices needs a correspondingly large stack.
void dfs(ll x, ll p)
{
    const int kBits = 31;  // must match the second dimension of sum
    ll ans = 0;

    // Seed sum[x] with the single-vertex path {x} and count its own value.
    for (int j = 0; j < kBits; j++)
    {
        const int bit = (A[x] >> j) & 1;
        sum[x][j][bit] += 1;
        ans += (ll)bit << j;
    }

    for (int i : v[x])
    {
        if (i == p) continue;
        dfs(i, x);
        for (int j = 0; j < kBits; j++)
        {
            const int bit = (A[x] >> j) & 1;
            // sum[i][j][*]: paths below x through child i (x excluded).
            // sum[x][j][*]: paths from x into earlier children, or {x} (x included).
            // Joining one of each gives a path through x whose XOR bit j is set
            // exactly when the two bits differ.
            const ll crossing = (ll)sum[i][j][0] * sum[x][j][1]
                              + (ll)sum[i][j][1] * sum[x][j][0];
            ans += crossing * (1LL << j);
            // Extend child i's paths up to x: A[x]'s bit j flips their parity.
            sum[x][j][bit] += sum[i][j][0];
            sum[x][j][bit ^ 1] += sum[i][j][1];
        }
    }
    anss += ans;
}
int main()
{
cin >> n;
FOR(i,1,n+1)
{
cin >> A[i];
}
FOR(i,1,n)
{
ll a, b;
cin >> a >> b;
v[a].pb(b);
v[b].pb(a);
}
dfs(1, 1);
cout << anss << "\n";
}
/*
Sample input 1 (path 1-2-3).
Hand-computed answer (sum of path XORs, single vertices included): 10
3
1 2 3
1 2
2 3
*/
/*
Sample input 2.
Hand-computed answer (sum of path XORs, single vertices included): 64
5
2 3 4 2 1
1 2
1 3
3 4
3 5
*/
/*
Sample input 3.
Hand-computed answer (sum of path XORs, single vertices included): 85
6
5 4 1 3 3 3
3 1
3 5
4 3
4 2
2 6
*/
/*
 * Judge results pasted into this file (column headers translated from Korean:
 * 결과 = Result, 실행 시간 = Execution time, 메모리 = Memory).
 *
 *   #  | Result    | Time   | Memory   | Grader output
 *   1  | Correct   |   4 ms |  2684 KB | Output is correct
 *   2  | Correct   |   3 ms |  2680 KB | Output is correct
 *   3  | Correct   |   4 ms |  2808 KB | Output is correct
 *   4  | Correct   |   6 ms |  3000 KB | Output is correct
 *   5  | Correct   |   5 ms |  2936 KB | Output is correct
 *   6  | Incorrect | 197 ms | 40248 KB | Output isn't correct
 *   7  | Incorrect | 188 ms | 40284 KB | Output isn't correct
 *   8  | Incorrect | 183 ms | 32504 KB | Output isn't correct
 *   9  | Incorrect | 191 ms | 31684 KB | Output isn't correct
 *   10 | Incorrect | 233 ms | 30940 KB | Output isn't correct
 */