#include "bits/stdc++.h"
using namespace std;
// Expands to the begin/end iterator pair of a container, for algorithm calls.
#define all(a) (a).begin(), (a).end()
// 64-bit integer shorthand. A type alias (not a macro) obeys scoping and also
// works in functional casts such as ll(x).
using ll = long long;
// Exact integer constants instead of floating literals (2e5, 1e9) converted to int.
constexpr int MAX_N = 200000 + 5;  // array capacity; presumably max vertex count + slack
constexpr int MOD = 1000000007;    // NOTE(review): not referenced in the visible code
constexpr int INF = 1000000000;    // "impossible" sentinel used as -INF in dfs()
constexpr int LOG = 30;            // NOTE(review): not referenced in the visible code
string s;                // vertex labels; solve() prepends one char so s[v] is vertex v's label
vector<int> adj[MAX_N];  // adjacency lists of the tree
int dp[MAX_N][2];        // DP table filled by dfs(); solve() prints dp[1][0]
// Tree DP over the subtree rooted at `u`, whose parent is `p`. solve() calls
// dfs(1, 0), so 0 acts as a sentinel parent (assuming vertices are numbered
// from 1).
//
// Reads the globals `adj` and `s`. Writes dp[v][0] and dp[v][1] for every
// vertex v of the subtree. With g(v) = (s[v] == '1' ? 1 : 0), and with
// maxima over an empty child set taken as 0:
//   dp[v][1] = max(sum_c dp[c][1] - g(v), g(v))
//   dp[v][0] = max(g(v) ? 1 + max_c dp[c][1] : -INF,
//                  max_c dp[c][0],
//                  sum_c dp[c][1] - g(v))
// NOTE(review): dp[v][0] looks like the best answer inside v's subtree, and
// dp[v][1] the best value of a configuration that can still be extended
// through v's parent. Confirm against the problem statement.
//
// The traversal uses an explicit stack instead of recursion. On a
// path-shaped tree of ~2e5 vertices, recursion would nest ~2e5 call frames,
// which can overflow a default-sized thread stack on some judges.
void dfs(int u, int p)
{
    // Pre-order list of (vertex, parent) pairs. A vertex is appended before
    // any of its descendants, so walking the list backwards finishes every
    // child before its parent.
    vector<pair<int, int>> order;
    vector<pair<int, int>> pending;
    pending.emplace_back(u, p);
    while (!pending.empty())
    {
        const pair<int, int> cur = pending.back();
        pending.pop_back();
        order.push_back(cur);
        for (int v : adj[cur.first])
            if (v != cur.second) pending.emplace_back(v, cur.first);
    }

    for (auto it = order.rbegin(); it != order.rend(); ++it)
    {
        const int x = it->first;
        const int parent = it->second;
        const int gen = (s[x] == '1') ? 1 : 0;

        // A single pass over the children gathers everything both states need.
        int openSum = 0;     // sum of dp[c][1]
        int bestOpen = 0;    // max of dp[c][1] (0 for a leaf)
        int bestClosed = 0;  // max of dp[c][0] (0 for a leaf)
        for (int v : adj[x]) if (v != parent)
        {
            openSum += dp[v][1];
            bestOpen = max(bestOpen, dp[v][1]);
            bestClosed = max(bestClosed, dp[v][0]);
        }

        dp[x][1] = max(openSum - gen, gen);
        dp[x][0] = max({gen ? 1 + bestOpen : -INF, bestClosed, openSum - gen});
    }
}
// Handles one test case:
// 1. Reads the vertex count n, then n - 1 undirected edges, then the label string.
// 2. Builds the global adjacency lists.
// 3. Runs the DP rooted at vertex 1 and prints dp[1][0].
void solve()
{
    int n;
    cin >> n;
    // Exactly n - 1 edges follow (none when n <= 1).
    for (int remaining = n; remaining > 1; --remaining)
    {
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    cin >> s;
    // Shift the labels right by one so that s[v] is the label of vertex v.
    s.insert(s.begin(), '?');
    dfs(1, 0);
    cout << dp[1][0] << "\n";
}
// Entry point. Unsyncs iostreams from C stdio and unties them for faster I/O,
// then runs solve() once per test case. The count is fixed at 1; re-enable
// the commented-out read for multi-test input.
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int tc = 1;
    // cin >> tc;
    for (; tc > 0; --tc) solve();
}
/*
Judge submission-status table pasted after the code (results were still loading):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/