#include <bits/stdc++.h>
using namespace std;
// Shared state for the tree DP (vertices are 1-indexed).
using ll = long long;
// Sentinel for "unreachable"; far larger than any real press count.
constexpr ll inf = 1'000'000'000'000LL;
// Number of vertices.
int n;
// Initial value of each vertex (the DP indexes with it, so expected 0/1).
int a[100005];
// Undirected adjacency lists of the tree.
std::vector<int> adj[100005];
// dp[u][pressed][value]: minimum presses inside u's subtree; filled by dfs().
ll dp[100005][2][2];
void dfs(int u, int par) {
for (auto v : adj[u]) {
if (v != par) {
dfs(v, u);
}
}
vector <int> child;
for (auto v : adj[u]) {
if (v != par) {
child.push_back(v);
}
}
array<array<ll, 2>, 2> cur;
cur[0][0] = 0;
cur[0][1] = 0;
cur[1][0] = inf;
cur[1][1] = inf;
for (int v : child) {
array<array<ll, 2>, 2> nxt;
for (int k = 0; k < 2; k++) {
for (int j = 0; j < 2; j++) {
nxt[k][j] = inf;
}
}
for (int k = 0; k < 2; k++) {
for (int j = 0; j < 2; j++) {
if (cur[k][j] >= inf) continue;
if (nxt[k][j] > cur[k][j] + dp[v][0][j]) {
nxt[k][j] = cur[k][j] + dp[v][0][j];
}
int nk = k ^ 1;
if (nxt[nk][j] > cur[k][j] + dp[v][1][j]) {
nxt[nk][j] = cur[k][j] + dp[v][1][j];
}
}
}
cur = nxt;
}
dp[u][0][a[u]] = min(inf, cur[0][0]);
dp[u][1][a[u]] = min(inf, 1 + cur[1][1]);
dp[u][0][a[u]^1] = min(inf, cur[0][1]);
dp[u][1][a[u]^1] = min(inf, 1 + cur[1][0]);
}
// Reads n, then n-1 undirected edges, then the n initial vertex values
// (expected to be 0/1, since dfs indexes dp with them). Runs the tree DP from
// vertex 1 and prints the minimum number of presses that makes every value 0,
// or "impossible" when no set of presses achieves it.
signed main() {
    // The input holds about 3n integers and C stdio is never used, so
    // unsynchronised, untied iostreams are safe and considerably faster.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n;
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for (int i = 1; i <= n; i++) cin >> a[i];

    // Every DP state starts as unreachable.
    for (int i = 1; i <= n; i++) {
        for (int j = 0; j < 2; j++) {
            for (int k = 0; k < 2; k++) {
                dp[i][j][k] = inf;
            }
        }
    }

    dfs(1, 0);
    // The root has no parent to flip it, so its value must already be 0.
    ll ans = min(dp[1][0][0], dp[1][1][0]);
    if (ans >= inf) cout << "impossible";
    else cout << ans;
    return 0;
}
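/*
 * Judge submission-results table pasted from the grading page (the results
 * were still loading when it was copied). Kept inside a comment so that the
 * file compiles; a bare "# |" line is an invalid preprocessing directive.
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */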