Submission #983478

# | Time | Username | Problem | Language | Result | Execution time | Memory
983478 | OAleksa | The Xana coup (BOI21_xanadu) | C++14
30 / 100
45 ms | 21636 KiB
#include <bits/stdc++.h>
#define f first
#define s second
using namespace std;
const int N = 1e5 + 69;
// Sentinel cost meaning "infeasible". Pressing a vertex twice cancels out, so
// any real answer is at most n <= 1e5; every value >= inf means impossible.
const int inf = 1e9 + 69;
// n: number of vertices; a[v]: initial lamp state of v (used as a 0/1 index).
// dp[v][st][pr]: minimum number of presses inside the subtree of v such that
//   every vertex strictly below v ends switched off, v itself is pressed iff
//   pr == 1, and v's state (not counting a possible press of its parent) is st.
//   All stored values are clamped to inf.
int n, dp[N][2][2], a[N];
// Adjacency lists of the tree; main reads vertices as 1..n.
std::vector<int> g[N];

// Saturating conversion: any cost >= inf collapses to inf, so sums of
// "infeasible" sentinels can never overflow int.
static int clamp_cost(long long cost) {
	return cost >= inf ? inf : static_cast<int>(cost);
}

// Compute dp[x][*][*] from the children of x (all neighbours except px).
// Every child's dp entries must already be filled in.
static void compute_node(int x, int px) {
	for (int pr = 0; pr < 2; pr++) {
		// If x is pressed (pr == 1) its press flips every child once more, so
		// each child must be in state pr before that flip to end up at 0.
		// best[q]: cheapest total over the children seen so far such that the
		// number of pressed children has parity q.
		long long best[2] = {0, inf};
		for (int u : g[x]) {
			if (u == px)
				continue;
			const long long keep = dp[u][pr][0];  // child u not pressed
			const long long push = dp[u][pr][1];  // child u pressed
			const long long even = std::min(best[0] + keep, best[1] + push);
			const long long odd = std::min(best[0] + push, best[1] + keep);
			best[0] = clamp_cost(even);
			best[1] = clamp_cost(odd);
		}
		// Each pressed child flips x, and x's own press flips it too:
		// state of x = a[x] ^ pr ^ (parity of pressed children).
		for (int q = 0; q < 2; q++)
			dp[x][a[x] ^ pr ^ q][pr] = clamp_cost(best[q] + pr);
	}
}

// Fill dp for every vertex in the subtree rooted at v, where p is v's parent
// (main passes 0 for the root). The traversal is iterative, so a path-shaped
// tree with ~1e5 vertices cannot overflow the call stack.
void dfs(int v, int p) {
	std::vector<std::pair<int, int>> order;  // (vertex, parent), parents first
	std::vector<std::pair<int, int>> todo;
	todo.emplace_back(v, p);
	while (!todo.empty()) {
		const std::pair<int, int> cur = todo.back();
		todo.pop_back();
		order.push_back(cur);
		for (int u : g[cur.first])
			if (u != cur.second)
				todo.emplace_back(u, cur.first);
	}
	// Reverse pre-order handles every child before its parent.
	for (auto it = order.rbegin(); it != order.rend(); ++it)
		compute_node(it->first, it->second);
}
int main() {
	ios::sync_with_stdio(false);
	cin.tie(0);
	cout.tie(0);
	// Input: vertex count, then n - 1 undirected edges, then one lamp state
	// per vertex 1..n.
	cin >> n;
	for (int e = 0; e < n - 1; e++) {
		int x, y;  // endpoints; named so they do not shadow the global a[]
		cin >> x >> y;
		g[x].push_back(y);
		g[y].push_back(x);
	}
	for (int v = 1; v <= n; v++)
		cin >> a[v];
	dfs(1, 0);
	// Vertex 1 is the root (no parent press), so its own state must already be 0.
	const int best = min(dp[1][0][0], dp[1][0][1]);
	if (best < inf)
		cout << best << '\n';
	else
		cout << "impossible\n";
	return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...