# | Time | Username | Problem | Language | Result | Execution time | Memory |
---|---|---|---|---|---|---|---|
561398 |  | AGE | Deblo (COCI18_deblo) | C++14 |  | 247 ms | 65536 KiB |
This submission was migrated from the previous version of oj.uz, which used a different machine for grading, so it may get a different result if resubmitted.
#include<bits/stdc++.h>
#define F first
#define S second
#define pb push_back
#define ll long long
using namespace std;
// N: bound for per-vertex tables (sized for vertex ids up to about 1e5).
// M: not referenced in the visible code; kept for compatibility.
const int N=100005,M=2010;
// Modulus used by summ/mult. It must be a 64-bit constant: as a `const int`
// the literal 100000000000031 was silently narrowed to an unrelated value.
const long long mod=100000000000031LL;
// Adjacency lists of the tree, indexed by vertex id.
std::vector<int>adj[N];

// Returns (x + y) mod `mod`.
// Assumes non-negative arguments; a negative argument can yield a negative result.
long long summ(long long x,long long y){
    return ((x%mod)+(y%mod))%mod;
}

// Returns (x * y) mod `mod`.
// Two residues below ~1e14 multiply to ~1e28, which overflows long long,
// so the product is formed in a 128-bit integer (GCC/Clang extension).
long long mult(long long x,long long y){
    return static_cast<long long>(static_cast<__int128>(x%mod)*(y%mod)%mod);
}
// Per-vertex tables; the second index (0..24) is a bit position.
long long dp[N][25];   // dp[v][j]: downward paths from v (inside v's subtree) whose XOR has bit j set
long long num[N];      // num[v]: number of vertices in v's subtree
long long a[N];        // a[v]: value written on vertex v
long long ans[N][25];  // ans[v][j]: paths starting at v (to any vertex) whose XOR has bit j set
long long xx[25];      // scratch row used by dfs3
int aa[N][25];         // aa[v][j]: bit j of a[v] (0 or 1)
// Post-order DFS that stores in num[node] the number of vertices in node's
// subtree, where `par` is the vertex we arrived from (0 for the root).
// The parent is explicitly excluded, and num[node] is assigned rather than
// accumulated. The original summed over every neighbour, which was only
// correct because num[par] happened to still be 0 at that moment.
void dfs2(int node,int par){
    num[node]=1; // the vertex itself
    for(int child:adj[node]){
        if(child==par)
            continue;
        dfs2(child,node);
        num[node]+=num[child];
    }
}
// Post-order DFS filling dp[node][j]: the number of paths that start at `node`
// and go down into its subtree (the single-vertex path included) whose XOR of
// values has bit j set. Requires num[] (subtree sizes) to be computed already.
//
// For each child c, the num[c] paths going down from c extend to paths from node.
// If bit j of node is 1, the bit flips, so the count is num[c] - dp[c][j];
// otherwise it is dp[c][j]. The single-vertex path {node} contributes aa[node][j].
// This form covers leaves too, so no special case is needed.
void dfs(int node,int par){
    for(int child:adj[node])
        if(child!=par)
            dfs(child,node);
    for(int j=0;j<=24;j++){
        long long cnt=aa[node][j]; // path consisting of node alone
        for(int child:adj[node]){
            if(child==par)
                continue; // the parent is not part of node's subtree
            cnt+=aa[node][j]==1 ? num[child]-dp[child][j] : dp[child][j];
        }
        dp[node][j]=cnt;
    }
}
// Number of vertices; read in main and used by dfs3 to size the part of the
// tree outside a subtree (n - num[node]).
int n;
// Pre-order rerooting pass. It turns downward counts (dp) into ans[node][j]:
// the number of paths starting at `node` and ending anywhere in the tree whose
// XOR has bit j set. Vertex 1 is skipped, so ans[1][*] must already hold dp[1][*].
//
// For a non-root node:
//   fromParNotDown = paths from par that avoid node's subtree, with bit j set
//                    (kept in the scratch row xx);
//   viaPar         = those paths extended by node (the bit flips if node has it);
//   ans[node][j]  += viaPar + dp[node][j].
void dfs3(int node,int par){
    if(node!=1){
        for(int bit=0;bit<=24;bit++){
            const long long down=dp[node][bit];
            // Paths from par that enter node's subtree and have bit set.
            const long long parIntoNode=(aa[par][bit]==1) ? num[node]-down : down;
            xx[bit]=ans[par][bit]-parIntoNode;
            // All n - num[node] such paths, re-anchored at node.
            const long long viaPar=(aa[node][bit]==1) ? (n-num[node])-xx[bit] : xx[bit];
            ans[node][bit]+=viaPar;
            ans[node][bit]+=down;
        }
    }
    for(int child:adj[node]){
        if(child!=par)
            dfs3(child,node);
    }
}
// Reads a tree with n vertices (values first, then n-1 edges) and prints the
// sum, over all unordered vertex pairs {u, v} with u <= v, of the XOR of the
// values on the path u..v.
int main()
{
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cin>>n;
    for(int i=1;i<=n;i++){
        int x;
        std::cin>>x;
        a[i]=x;
        for(int j=0;j<=24;j++)
            if(x&(1<<j))
                aa[i][j]=1;
    }
    for(int i=1;i<n;i++){
        int x,y;
        std::cin>>x>>y;
        adj[x].push_back(y);
        adj[y].push_back(x);
    }
    dfs2(1,0);
    dfs(1,0);
    for(int j=0;j<=24;j++)
        ans[1][j]=dp[1][j];
    dfs3(1,0);

    // Summing ans over every start vertex counts each pair u != v twice
    // (once per endpoint) and each single-vertex path once. The total can
    // reach ~1e10 pairs * 2^22, so it must be accumulated in 64 bits
    // (the original `int` overflowed).
    long long answer=0;
    for(int i=1;i<=n;i++)
        for(int j=0;j<=24;j++)
            answer+=ans[i][j]*(1LL<<j);
    long long sum=0; // single-vertex paths: up to 1e5 * 3e6, also too big for int
    for(int i=1;i<=n;i++)
        sum+=a[i];
    answer=(answer-sum)/2+sum;
    std::cout<<answer<<'\n';
    return 0;
}
Compilation message (stderr)
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |