#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
// Order-statistic set of distinct ints built on the GNU pb_ds red-black tree;
// provides find_by_order(k) and order_of_key(x) on top of the usual set API.
using ordered_set = __gnu_pbds::tree<int, __gnu_pbds::null_type, std::less<int>, __gnu_pbds::rb_tree_tag, __gnu_pbds::tree_order_statistics_node_update>;
// Fast reader: stores the next non-negative decimal integer from stdin into the
// lvalue x. Every non-digit before the number is skipped (so a leading '-' is
// ignored), and the character that ends the number is consumed. If EOF is hit
// before any digit, x becomes 0 instead of looping forever.
#define scan(x) do{ \
    while(((x)=getchar())!=EOF && ((x)<'0' || (x)>'9')){} \
    if((x)==EOF){ (x)=0; break; } \
    for((x)-='0'; (_=getchar())>='0' && _<='9'; (x)=(x)*10+(_-'0')); \
}while(0)
// Scratch character for scan(). It is an int rather than a char so that EOF (-1)
// never compares as a digit, even on platforms where plain char is unsigned.
int _;
// ---- Shorthand macros ------------------------------------------------------
// Erases *adjacent* duplicates (std::unique); sort the container first to dedupe fully.
#define complete_unique(a) a.erase(std::unique(a.begin(),a.end()),a.end())
// Expands to the iterator pair "a.begin(),a.end()" for algorithm calls.
#define all(a) a.begin(),a.end()
#define println printf("\n");
#define readln(x) std::getline(std::cin,x);
#define pb push_back
// Plain newline (no flush). Because this is a macro, writing `std::endl`
// after this line no longer compiles.
#define endl "\n"
#define INT_INF 0x3f3f3f3f
#define LL_INF 0x3f3f3f3f3f3f3f3f
#define MOD 1000000007
#define mp std::make_pair
// Wrapped in do/while(0) so it acts as a single statement. The old two-statement
// form meant `if (c) fastio` guarded only cin.tie. The trailing ';' keeps the
// existing bare `fastio` usage valid.
#define fastio do{ std::cin.tie(nullptr); std::ios_base::sync_with_stdio(false); }while(0);
#define MAXN 100005

// ---- Type aliases ----------------------------------------------------------
using ull = unsigned long long;
using ll = long long;
using ld = long double;
using umii = std::unordered_map<int,int>;
using pii = std::pair<int,int>;
using pdd = std::pair<double,double>;
using pll = std::pair<ll,ll>;
using triple = std::pair<int,pii>;
// NOTE(review): with `using namespace std;` in C++17, an unqualified use of
// `byte` is ambiguous with std::byte; write ::byte if it is ever needed.
using byte = std::int8_t;

// ---- Random numbers ----------------------------------------------------------
// Generator seeded from the wall clock.
std::mt19937 g1(std::time(nullptr));
// Uniform random int in the inclusive range [a, b].
int randint(int a, int b){ return std::uniform_int_distribution<int>(a, b)(g1); }
// Uniform random long long in the inclusive range [a, b].
ll randlong(ll a, ll b){ return std::uniform_int_distribution<long long>(a, b)(g1); }

// ---- Number theory -----------------------------------------------------------
// Greatest common divisor (Euclid's algorithm).
ll gcd(ll a, ll b){
    while(b != 0){
        ll r = a % b;
        a = b;
        b = r;
    }
    return a;
}

// Least common multiple. Dividing before multiplying keeps the intermediate
// value no larger than the result (the old a*b/gcd overflowed whenever a*b
// exceeded the ll range). lcm(0, x) is 0; this also avoids dividing by gcd(0,0).
ll lcm(ll a, ll b){
    if(a == 0 || b == 0) return 0;
    return a / gcd(a, b) * b;
}

// Returns b^exp mod `mod` by iterative binary exponentiation.
// Requires exp >= 0, and (mod-1)^2 must fit in a long long. The base is reduced
// first, so any b (including b >= mod or negative b) is handled without overflow.
// exp == 0 yields 1.
ll fpow(ll b, ll exp, ll mod){
    b %= mod;
    if(b < 0) b += mod;
    ll result = 1;
    while(exp > 0){
        if(exp & 1) result = result * b % mod;
        b = b * b % mod;
        exp >>= 1;
    }
    return result;
}

// Modular division i / j for a *prime* modulus, using Fermat's little theorem
// (j^(mod-2) is the inverse of j). The result is always in [0, mod).
ll divmod(ll i, ll j, ll mod){
    i %= mod;
    if(i < 0) i += mod;
    return i * fpow(j, mod - 2, mod) % mod;
}
// Global state (zero-initialised as statics).
int num_nodes;                 // number of tree nodes, numbered 1..num_nodes
int vals[MAXN];                // vals[i]: value stored on node i
int sz[MAXN];                  // subtree sizes filled in by centroid()
int upd[MAXN];                 // upd[i] == b: node i already removed as a centroid in pass b
int b;                         // current bit index (pass number)
int hb;                        // highest bit index processed by main()
int mx;                        // largest node value
int vv;                        // 1 << b, the bit handled in the current pass
ll res;                        // accumulated answer
ll cnt[2];                     // per-parity path counts from earlier subtrees of the current centroid
ll tmp[2];                     // per-parity path counts for the subtree dfs() is walking
pii ret;                       // (best largest-piece size, node) found by centroid()
vector<int> connections[MAXN]; // adjacency lists
// Fills sz[] for the part of the current component hanging below `node`
// (coming from `prev`, skipping nodes with upd[] == b). While doing so it records
// in the global `ret` the node whose largest remaining piece is smallest.
// psz is the expected size of the whole component. Returns ret.second, the best
// candidate found so far.
int centroid(int node, int prev, int psz){
    int heaviest = 0;
    sz[node] = 1;
    for(int child : connections[node]){
        if(child != prev && upd[child] != b){
            centroid(child, node, psz);
            sz[node] += sz[child];
            heaviest = max(heaviest, sz[child]);
        }
    }
    // The piece "above" node is everything in the component outside its subtree.
    const int above = psz - sz[node];
    if(above > heaviest) heaviest = above;
    if(heaviest < ret.first) ret = mp(heaviest, node);
    return ret.second;
}
void dfs(int node, int prev, int rt, int has_b){
if(has_b){
if(vals[rt]&vv) res+=cnt[1]*vv;
else res+=(cnt[0]+1)*vv;
}else{
if(vals[rt]&vv) res+=(cnt[0]+1)*vv;
else res+=cnt[1]*vv;
}
tmp[has_b]++;
for(int check:connections[node]){
if(check == prev || upd[check] == b) continue;
dfs(check,node,rt,has_b^(vals[check]&vv?1:0));
}
}
void solve(int node, int msz){
ret = mp(INT_MAX,-1);
node = centroid(node,-1,msz);
upd[node] = b;
cnt[0] = cnt[1] = 0;
if(vals[node]&vv) res+=vv;
for(int check:connections[node]){
if(upd[check] == b) continue;
tmp[0] = tmp[1] = 0;
dfs(check,node,node,vals[check]&vv?1:0);
cnt[1]+=tmp[1], cnt[0]+=tmp[0];
}
for(int check:connections[node]){
if(upd[check] == b) continue;
solve(check,sz[check]);
}
}
// Reads n, then the n node values, then n-1 undirected edges "u v" (1-indexed).
// Sums, bit by bit, the XOR over all paths and prints the total.
int main(){
    // Unreadable input is treated as an empty tree, so the answer printed is 0.
    if(scanf("%d",&num_nodes) != 1) num_nodes = 0;
    for(int i=1; i<=num_nodes; i++){
        scan(vals[i]);
        mx = max(mx,vals[i]);
    }
    for(int i=1; i<num_nodes; i++){
        int u, v;  // not named `b`, which would shadow the global pass index
        scan(u); scan(v);
        connections[u].pb(v);
        connections[v].pb(u);
    }
    // hb = index of the highest set bit of mx, or -1 when mx == 0.
    // The old (int)log2(mx) was undefined for mx == 0, since log2(0) is -inf.
    hb = -1;
    while((mx >> (hb + 1)) != 0) hb++;
    memset(upd,-1,sizeof upd);
    for(b=0; b<=hb; b++){
        vv = (1<<b);
        solve(1,num_nodes);
    }
    printf("%lld\n",res);
}
/* Sample input: n, the n node values, then n-1 edges "a b" (1-indexed nodes).
8
1 2 3 4 5 6 7 8
1 2
1 3
1 4
1 5
2 8
5 6
6 7
*/
/*
Judge report from an earlier submission. It is kept inside a comment so the file compiles.
Compilation message
deblo.cpp: In function 'int main()':
deblo.cpp:98:7: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
scanf("%d",&num_nodes);
~~~~~^~~~~~~~~~~~~~~~~
# |
결과 (Result) |
실행 시간 (Execution time) |
메모리 (Memory) |
Grader output |
1 |
Correct |
4 ms |
3064 KB |
Output is correct |
2 |
Correct |
4 ms |
3064 KB |
Output is correct |
3 |
Correct |
5 ms |
3064 KB |
Output is correct |
4 |
Correct |
13 ms |
3064 KB |
Output is correct |
5 |
Correct |
11 ms |
3152 KB |
Output is correct |
6 |
Execution timed out |
1081 ms |
14840 KB |
Time limit exceeded |
7 |
Execution timed out |
1079 ms |
14968 KB |
Time limit exceeded |
8 |
Execution timed out |
1073 ms |
8440 KB |
Time limit exceeded |
9 |
Execution timed out |
1064 ms |
7800 KB |
Time limit exceeded |
10 |
Execution timed out |
1071 ms |
7036 KB |
Time limit exceeded |
*/