#include <bits/stdc++.h>
using namespace std;
// ---------------------------------------------------------------------------
// Global state for the centroid decomposition below. Vertices are 1-indexed;
// every array is sized 200005, so vertex ids must stay below that bound.
// ---------------------------------------------------------------------------
int n;                       // number of vertices (read in main)
std::vector<int> g[200005];  // adjacency lists of the tree
int s[200005];               // s[i] = digit read for vertex i (presumably 0 or 1 -- confirm with the input format)
int sz[200005];              // subtree sizes inside the component being processed (filled by dfssz)
int cntpg[200005];           // sum of s[] over the strict ancestors of a vertex, counted from the current centroid
int dp[200005];              // DP value of a vertex inside the current centroid-rooted component (see updDp)
int mx[200005];              // max(0, max of dp[w] + cntpg[w] over w in the vertex's subtree)
int ans[200005];             // ans[r] = dp[r], computed while r was the centroid/root of its component
bool removed[200005];        // true once a vertex has been taken as a centroid; no traversal enters it again

// True when a traversal standing at some vertex must not step into v: v is the
// vertex we came from (p1), the explicitly excluded vertex (p2), or a centroid
// already consumed by an earlier level of solve(). Checking removed[] is what
// confines every traversal to the current component; excluding only p2 is not
// enough, because older centroids would otherwise be walked through again.
static bool blocked(int v,int p1,int p2)
{
    return v==p1 || v==p2 || removed[v];
}

// Fills sz[] for the component containing u, rooted at u, and adds the number
// of vertices visited to cnt.
void dfssz(int u,int p1,int p2,int &cnt)
{
    cnt++;
    sz[u]=1;
    for(int v:g[u])
    {
        if(!blocked(v,p1,p2))
        {
            dfssz(v,u,p2,cnt);
            sz[u]+=sz[v];
        }
    }
}

// Returns a centroid of the component of size nn that contains u. It relies on
// sz[] from a dfssz() call rooted at the same u. It steps into child v whenever
// the part left behind (nn - sz[v]) has at most nn/2 vertices. sz strictly
// decreases along the walk, so it terminates.
int dfsfc(int u,int p1,int p2,int nn)
{
    for(int v:g[u])
        if(!blocked(v,p1,p2) && nn-sz[v]<=nn/2)
            return dfsfc(v,u,p2,nn);
    return u;
}

// Computes dp[u] and mx[u] from the already-computed mx[] of u's eligible
// children.
//  - No eligible children: dp[u] = s[u].
//  - Otherwise: dp[u] = max(0, max over k >= 1 of
//      (sum of the k largest child mx values) - cntpg[u] * k).
// Then mx[u] = max(0, largest child mx, dp[u] + cntpg[u]).
void updDp(int u,int p1,int p2)
{
    std::vector<int> vals;
    vals.reserve(g[u].size());
    mx[u]=0;
    for(int v:g[u])
    {
        if(!blocked(v,p1,p2))
        {
            mx[u]=std::max(mx[u],mx[v]);
            vals.push_back(mx[v]);
        }
    }
    if(vals.empty())
        dp[u]=s[u];
    else
    {
        std::sort(vals.begin(),vals.end(),std::greater<int>());
        // Both the running sum and cntpg[u] * k can exceed INT_MAX on deep, wide
        // trees (e.g. ~5e4 ancestors times ~5e4 children), even when their
        // difference is small. Accumulate in 64 bits to avoid signed overflow.
        // NOTE(review): as written, every vals[i] >= cntpg[u], because
        // mx[v] >= dp[v] + cntpg[v], dp[v] >= 0 and dfsdp sets
        // cntpg[v] = cntpg[u] + s[u] (assuming s[] is non-negative). Each extra
        // term is therefore non-negative and the maximum is always reached at the
        // full prefix. If a shorter prefix was meant to win sometimes, re-check
        // this recurrence against the problem statement.
        long long best=0;
        long long sum=0;
        for(int i=0; i<(int)vals.size(); i++)
        {
            sum+=vals[i];
            best=std::max(best,sum-1LL*cntpg[u]*(i+1));
        }
        dp[u]=(int)best;
    }
    mx[u]=std::max(mx[u],dp[u]+cntpg[u]);
}

// Post-order DP over the component containing u. Before descending into a
// child v it sets cntpg[v] = cntpg[u] + s[u]; the caller seeds cntpg[] at the
// root. The children are finished before updDp(u) runs.
void dfsdp(int u,int p1,int p2)
{
    for(int v:g[u])
    {
        if(!blocked(v,p1,p2))
        {
            cntpg[v]=cntpg[u]+s[u];
            dfsdp(v,u,p2);
        }
    }
    updDp(u,p1,p2);
}

// Centroid-decomposition driver for the component containing u.
//  - p is the vertex (previous centroid) through which u was reached, or 0 at
//    the top level.
//  - nn is the starting value of the vertex counter; callers use the default 0.
// It finds the component's centroid r, runs the DP rooted at r over this
// component only, records ans[r] = dp[r], marks r as removed, and recurses into
// every remaining neighbouring component.
void solve(int u,int p,int nn=0)
{
    dfssz(u,0,p,nn);
    int r=dfsfc(u,0,p,nn);
    // The DP must run before recursing: the recursive calls mark every vertex of
    // the sub-components as removed, which would hide them from this DP.
    cntpg[r]=0;
    dfsdp(r,0,p);
    ans[r]=dp[r];
    removed[r]=true;
    for(int v:g[r])
        if(!removed[v])
            solve(v,r);
}
// Reads the tree and one digit per vertex, runs the centroid decomposition from
// vertex 1, and prints the largest ans[] value (never below 0).
int main()
{
    // Only iostreams are used, so decoupling them from C stdio is safe. This
    // speeds up reading the n-1 edges and the n digits.
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    cin>>n;
    // n-1 undirected tree edges.
    for(int i=1; i<n; i++)
    {
        int a,b;
        cin>>a>>b;
        g[a].push_back(b);
        g[b].push_back(a);
    }
    // One digit per vertex; operator>> skips any whitespace between them.
    for(int i=1; i<=n; i++)
    {
        char ch;
        cin>>ch;
        s[i]=ch-'0';
    }
    solve(1,0);
    // Best value over all centroid roots.
    int maxim=0;
    for(int i=1; i<=n; i++)
        maxim=max(maxim,ans[i]);
    cout<<maxim<<"\n";
    return 0;
}
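/*
| #  | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory) | Grader output                  |
|----|---------------|------------------|-----------------|--------------------------------|
| 1  | Correct       | 2 ms             | 9564 KB         | Output is correct              |
| 2  | Runtime error | 750 ms           | 524288 KB       | Execution killed with signal 9 |
| 3  | Halted        | 0 ms             | 0 KB            | -                              |
*/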
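/*
| #  | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory) | Grader output                  |
|----|---------------|------------------|-----------------|--------------------------------|
| 1  | Correct       | 2 ms             | 9564 KB         | Output is correct              |
| 2  | Runtime error | 750 ms           | 524288 KB       | Execution killed with signal 9 |
| 3  | Halted        | 0 ms             | 0 KB            | -                              |
*/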
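/*
| #  | 결과 (Result) | 실행 시간 (Time) | 메모리 (Memory) | Grader output                  |
|----|---------------|------------------|-----------------|--------------------------------|
| 1  | Correct       | 2 ms             | 9564 KB         | Output is correct              |
| 2  | Runtime error | 750 ms           | 524288 KB       | Execution killed with signal 9 |
| 3  | Halted        | 0 ms             | 0 KB            | -                              |
*/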