#include <bits/stdc++.h>
using namespace std;
// Array capacity; vertex ids are used directly as indices (1-based in main()).
constexpr int Nmax = 200005;
// is_plant[i] holds 0/1 for vertex i once main() has converted the input
// characters; slot 0 is a padding character so ids line up with indices.
string is_plant;
// Adjacency lists of the input tree.
vector<int> v[Nmax];
// used[i] is set once vertex i has been taken as a centroid in solve().
bool used[Nmax];
// DP value per vertex, filled by calc() for the current centroid's component.
int dp[Nmax];
// Subtree sizes inside the current component, filled by dfs0().
int w[Nmax];
// Best value found so far; printed by main().
int ans{0};
// Number of vertices read by main().
int n;
// Size of the component currently being decomposed (set in solve()).
int All;
// Fills w[x] with the size of x's subtree, considering only vertices that are
// not yet marked in used[], with the traversal rooted at `node`.
// `dad` is the vertex we came from (0 means none; vertex ids start at 1).
void dfs0(int node, int dad = 0)
{
    w[node] = 1;
    for (const int next : v[node]) {
        const bool skip = used[next] || next == dad;
        if (skip) continue;
        dfs0(next, node);
        w[node] += w[next];
    }
}
// Looks for a vertex whose largest remaining piece (after deleting it) has at
// most All/2 vertices, using the sizes left in w[] by the last dfs0() call.
// Children are searched before `node` itself is tested, and the first vertex
// that qualifies is returned; -1 means nothing in this subtree qualifies.
int centroid(int node, int dad = 0)
{
    // The piece on the parent's side of `node`.
    int largest = All - w[node];
    for (const int next : v[node]) {
        if (used[next] || next == dad) continue;
        const int found = centroid(next, node);
        if (found != -1) return found;
        largest = max(largest, w[next]);
    }
    return largest <= All / 2 ? node : -1;
}
// Computes dp[] bottom-up over the unused component rooted at `node`.
// With S = sum of the children's dp values:
//   plant vertex     -> dp = max(1, S - 1)
//   non-plant vertex -> dp = S
void calc(int node, int dad = 0)
{
    int childSum = 0;
    for (const int next : v[node]) {
        if (used[next] || next == dad) continue;
        calc(next, node);
        childSum += dp[next];
    }
    dp[node] = is_plant[node] ? max(1, childSum - 1) : childSum;
    assert(dp[node] >= 0);
}
// Candidate where the centroid `node` combines the branches of all its
// remaining (unused) neighbours: the sum of their dp values minus
// is_plant[node]. Relies on calc(node) having just filled dp[].
void case1(int node)
{
    int total = 0;
    for (const int next : v[node])
        if (!used[next]) total += dp[next];
    const int candidate = total - is_plant[node];
    if (candidate > ans) ans = candidate;
}
// Candidates where the centroid `node` is a plant counted positively:
//   - on its own (value 1, no neighbouring branch used), or
//   - together with the branch of exactly one remaining neighbour
//     (that neighbour's dp + 1).
// Relies on calc(node) having just filled dp[].
void case2(int node)
{
    if(!is_plant[node]) return;
    // A lone plant is worth 1, matching calc(), which never lets a plant's dp
    // drop below 1. Without this line a component that is a single plant
    // (e.g. n == 1 with S == "1") is never credited: case1 yields 0 - 1 and
    // the loop below has no neighbours to examine, so 0 was printed.
    ans = max(ans, 1);
    for(auto it : v[node])
        if(!used[it])
            ans = max(ans, dp[it] + 1);
}
// Centroid-decomposition driver. It measures the component containing `node`,
// picks its centroid, computes dp[] rooted there and evaluates both candidate
// cases. It then removes the centroid and recurses into each remaining piece.
void solve(int node)
{
    dfs0(node);
    All = w[node];
    const int c = centroid(node);
    assert(c != -1);
    calc(c);
    case1(c);
    case2(c);
    used[c] = true;
    for (const int next : v[c]) {
        if (!used[next]) solve(next);
    }
}
int main()
{
    // freopen("input", "r", stdin);  // local debugging only
    // Fast, unsynchronised iostreams.
    cin.tie(nullptr);
    ios::sync_with_stdio(false);

    cin >> n;
    // Read the n - 1 undirected edges.
    for (int e = 1; e < n; ++e) {
        int a, b;
        cin >> a >> b;
        v[a].push_back(b);
        v[b].push_back(a);
    }

    // Prepend a padding character so is_plant[i] refers to vertex i (1-based),
    // then turn every character c into c - '0'.
    cin >> is_plant;
    is_plant = "#" + is_plant;
    for (char &ch : is_plant) ch -= '0';

    solve(1);
    cout << ans << '\n';
    return 0;
}
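/*
 * Judge result pasted from the submission page (headers translated from Korean):
 *
 * | # | Result    | Execution time | Memory  | Grader output        |
 * |---|-----------|----------------|---------|----------------------|
 * | 1 | Incorrect | 4 ms           | 4992 KB | Output isn't correct |
 * | 2 | Halted    | 0 ms           | 0 KB    | -                    |
 */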
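/*
 * Judge result pasted from the submission page (headers translated from Korean):
 *
 * | # | Result    | Execution time | Memory  | Grader output        |
 * |---|-----------|----------------|---------|----------------------|
 * | 1 | Incorrect | 4 ms           | 4992 KB | Output isn't correct |
 * | 2 | Halted    | 0 ms           | 0 KB    | -                    |
 */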
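/*
 * Judge result pasted from the submission page (headers translated from Korean):
 *
 * | # | Result    | Execution time | Memory  | Grader output        |
 * |---|-----------|----------------|---------|----------------------|
 * | 1 | Incorrect | 4 ms           | 4992 KB | Output isn't correct |
 * | 2 | Halted    | 0 ms           | 0 KB    | -                    |
 */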