#include <bits/stdc++.h>
#define int long long
using namespace std;
// ---- State shared by main() and dfs() ----------------------------------------
// Arrays are indexed by vertex number 1..a. Index 0 is never a real vertex and
// stays zero, which lets the root use 0 as its "parent" sentinel in dfs().

// Number of vertices, read by main().
int a;
// Adjacency lists; main() stores every input edge from both endpoints.
std::vector<int> z[1000005];
// open[v] = s[v] - '0' for the input string s (presumably 0 or 1); open[0] stays 0.
int open[1000005];
// Per-vertex value computed bottom-up by dfs() (recurrence documented on dfs()).
int dp[1000005];
// dist[v] = number of open vertices strictly above v on the path to the DFS root.
int dist[1000005];
// Best candidate recorded by dfs(); main() additionally folds every dp[v] into it.
int max1;

/// Fills dist[], dp[] and max1 for the component of `i`, rooted at `i`.
///
/// This is an iterative form of a recursive DFS. Recursion depth would equal
/// the tree height, and with arrays sized for ~10^6 vertices a long path could
/// overflow the call stack. The per-vertex update needs only the parent's
/// dist[] (available in pre-order) and the children's dp[] (available in
/// post-order). The sums and maximums involved do not depend on child order,
/// so the results match the recursive version exactly.
///
/// Recurrence for a vertex v whose children c are its neighbours other than
/// its parent:
///   dist[v] = dist[parent] + open[parent]
///   pre     = sum(dp[c]) - open[v]
///   max1    = max(max1, pre + (dist[v] > 0))
///   dp[v]   = max(1 if v is open else dp[v], pre)   (dp[v] starts at 0 for globals)
///
/// NOTE(review): the "+ (dist[v] > 0)" candidate is built from `pre`, not from
/// dp[v]. So an open leaf below an open ancestor scores 0 there, and for
/// n=2, edge 1-2, s="11" the program prints 1. If the task allows pairing an
/// open leaf with an open ancestor, this may need dp[v] -- confirm against the
/// problem statement.
///
/// @param i    root vertex of the traversal.
/// @param par  vertex treated as i's parent. It is skipped as a neighbour, and
///             its dist/open seed dist[i]. main() passes the all-zero sentinel 0.
/// Assumes the part of the graph reachable from i is acyclic (a tree).
void dfs(int i, int par) {
    // Pre-order list of (vertex, parent): a parent always precedes its descendants.
    std::vector<std::pair<int, int>> order;
    std::vector<std::pair<int, int>> pending;
    pending.emplace_back(i, par);
    while (!pending.empty()) {
        const std::pair<int, int> cur = pending.back();
        pending.pop_back();
        const int v = cur.first;
        const int p = cur.second;
        // p was popped before v was pushed, so dist[p] is already final.
        dist[v] = dist[p] + open[p];
        order.push_back(cur);
        for (const int c : z[v]) {
            if (c != p) {
                pending.emplace_back(c, v);
            }
        }
    }

    // Reverse pre-order handles every child before its parent (post-order).
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int v = it->first;
        const int p = it->second;
        if (open[v]) {
            dp[v] = 1;
        }
        int pre = 0;
        for (const int c : z[v]) {
            if (c != p) {
                pre += dp[c];
            }
        }
        pre -= open[v];
        max1 = std::max(max1, pre + (dist[v] > 0));
        dp[v] = std::max(dp[v], pre);
    }
}
signed main()
{
ios_base::sync_with_stdio(false);
cin.tie(NULL);
cin >> a;
for (int i=1;i<a;i++){
int x,y;
cin >> x >> y;
z[x].push_back(y);
z[y].push_back(x);
}
string s;
cin >> s;
s='#'+s;
for (int i=1;i<=a;i++){
open[i]=s[i]-'0';
// cout << open[i] << " " << "\n";
}
dfs(1,0);
for (int i=1;i<=a;i++){
max1=max(max1,dp[i]);
}
cout << max1 << "\n";
return 0;
}
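/*
 * Online-judge submission status table pasted after the source. It is not part
 * of the program, so it is kept only as a comment.
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/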