// Submission #759036
//
// #      | Time | Username | Problem                      | Language | Result    | Execution time | Memory
// 759036 |      | YENGOYAN | The Xana coup (BOI21_xanadu) | C++17    | 100 / 100 | 111 ms         | 34880 KiB
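// Earlier draft of the solution below, kept commented out for reference only.
// Note: its leaf initialization uses the opposite lamp-state convention from the
// live dfs(), and its solve() is unrelated scratch code that counts characters
// of a hard-coded string.
//#include <iostream>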
//#include <vector>
//using namespace std;
//using ll = long long;
//vector<vector<vector<ll>>> dp;
//vector<int> state;
//void dfs(int u, int p, vector<vector<int>>& gp)
//{
//	ll sum1 = 0, cnt1 = 0, mn1 = 1e18, sum2 = 0, cnt2 = 0, mn2 = 1e18;
//	for (int v : gp[u]) {
//		if (v == p) continue;
//
//		dfs(v, u, gp);
//		if (dp[v][1][0] <= dp[v][0][0]) {
//			sum1 += dp[v][1][0];
//			++cnt1;
//		}
//		else {
//			sum1 += dp[v][0][0];
//		}
//		mn1 = min(mn1, abs(dp[v][1][0] - dp[v][0][0]));
//		if (dp[v][1][1] <= dp[v][0][1]) {
//			sum2 += dp[v][1][1];
//			++cnt2;
//		}
//		else {
//			sum2 += dp[v][0][1];
//		}
//		mn2 = min(mn2, abs(dp[v][1][1] - dp[v][0][1]));
//
//	}
//	if (gp[u].size() == 1 && u != 0) {
//		if (!state[u]) {
//			dp[u][1][0] = 1;
//			dp[u][0][0] = 1e18;
//			dp[u][1][1] = 1e18;
//			dp[u][0][1] = 0;
//		}
//		else {
//			dp[u][1][0] = 1e18;
//			dp[u][0][0] = 0;
//			dp[u][1][1] = 1;
//			dp[u][0][1] = 1e18;
//		}
//		return;
//	}
//	if (cnt1 % 2 != state[u]) {
//		dp[u][0][0] = sum1;
//	}
//	else {
//		dp[u][0][0] = sum1 + mn1;
//	}
//	state[u] = !state[u];
//	if (cnt2 % 2 != state[u]) {
//		dp[u][1][0] = sum2 + 1;
//	}
//	else {
//		dp[u][1][0] = sum2 + mn2 + 1;
//	}
//	if (cnt1 % 2 != state[u]) {
//		dp[u][0][1] = sum1;
//	}
//	else {
//		dp[u][0][1] = sum1 + mn1;
//	}
//	state[u] = !state[u];
//	if (cnt2 % 2 != state[u]) {
//		dp[u][1][1] = sum2 + 1;
//	}
//	else {
//		dp[u][1][1] = sum2 + mn2 + 1;
//	}
//}
//
//void solve()
//{
//	string a(45, '0');
//	string b(34, '1');
//	string res = a;
//	res += b;
//	res += "10111110001000100000011010001001100001011100101000011101111100010000000011000101110011110";
//	int boris = 0, hayk = 0;
//	for (int i = 0; i < res.size(); ++i) {
//		if (res[i] == '0') ++hayk;
//		else ++boris;
//	}
//	cout << hayk << " " << boris << "\n";
//}
//int main()
//{
//	//int t; cin >> t; while (t--)
//	solve();
//}
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
using ll = long long;

// dp[u][a][b] = minimum number of toggles inside the subtree of u that leaves every
// lamp of that subtree (u included) at 0, where
//   a = 1 iff u itself is toggled (counted as 1 toggle),
//   b = 1 iff u's parent is toggled (flips u's lamp, costs nothing inside the subtree).
// Unsatisfiable configurations hold a large sentinel value (around 1e18).
// The caller sizes it to n x 2 x 2; dfs() fills it.
vector<vector<vector<ll>>> dp;
// state[u]: initial lamp value of node u; any non-zero value is treated as 1.
vector<int> state;

namespace {

// "Unreachable" sentinel; same magnitude as the 1e18 used previously.
constexpr ll kInf = 1000000000000000000LL;

// Saturating addition. Both operands are always kept in [0, kInf], so a + b <= 2e18
// fits in a long long; clamping stops sums over many unreachable children from
// overflowing (signed overflow is undefined behavior and used to wrap negative).
ll add_sat(ll a, ll b) { return min(a + b, kInf); }

// Combines the children of one node for a fixed choice of "is this node toggled".
// Every child independently takes its cheaper option (toggled / not toggled); if the
// resulting parity of toggled children is wrong, the cheapest single switch fixes it.
struct ChildSum {
	ll sum = 0;      // sum of the cheaper options, saturated at kInf
	int parity = 0;  // parity of the children whose cheaper option toggles them
	ll fix = kInf;   // min over children of |cost(toggled) - cost(not toggled)|

	void add(ll if_toggled, ll if_untoggled) {
		if (if_toggled < if_untoggled) {
			sum = add_sat(sum, if_toggled);
			parity ^= 1;
		} else {
			sum = add_sat(sum, if_untoggled);
		}
		fix = min(fix, abs(if_toggled - if_untoggled));
	}

	// Cheapest total when the number of toggled children must have parity `need`.
	ll cost(int need) const { return parity == need ? sum : add_sat(sum, fix); }
};

}  // namespace

// Fills dp[v][*][*] for every node v in the subtree of u, where the tree is rooted so
// that p is u's parent (pass p = -1 for the root). gp is the adjacency list.
// Reads state[] but does not modify it. Recurses once per tree level, so very deep
// (path-like) inputs rely on a sufficiently large call stack — NOTE(review): confirm
// the judge's stack limit if n can be large.
void dfs(int u, int p, const vector<vector<int>>& gp)
{
	ChildSum stay;  // children's costs when u is NOT toggled (their b = 0)
	ChildSum flip;  // children's costs when u IS toggled (their b = 1)
	for (int v : gp[u]) {
		if (v == p) continue;
		dfs(v, u, gp);
		stay.add(dp[v][1][0], dp[v][0][0]);
		flip.add(dp[v][1][1], dp[v][0][1]);
	}

	const int s = state[u] != 0;
	for (int a = 0; a < 2; ++a) {
		for (int b = 0; b < 2; ++b) {
			// u's final lamp is s ^ a ^ b ^ (toggled children mod 2), which must be 0,
			// so the toggled children must have parity s ^ a ^ b. A childless node
			// (leaf) falls out naturally: sum = 0, parity = 0, fix = kInf.
			const ll children = (a ? flip : stay).cost(s ^ a ^ b);
			dp[u][a][b] = a ? add_sat(children, 1) : children;
		}
	}
}

void solve()
{
	int n; cin >> n;
	vector<vector<int>> gp(n);
	state = vector<int>(n);
	dp = vector<vector<vector<ll>>>(n, vector<vector<ll>>(2, vector<ll>(2)));
	for (int i = 0; i < n - 1; ++i) {
		int u, v; cin >> u >> v; --u; --v;
		gp[u].push_back(v);
		gp[v].push_back(u);
	}
	for (int& x : state) { cin >> x; }
	dfs(0, -1, gp);
	if (min(dp[0][0][0], dp[0][1][0]) <= n) {
		cout << min(dp[0][0][0], dp[0][1][0]) << '\n';
	}
	else {
		cout << "impossible\n";
	}


}
// Entry point: exactly one test case per run. A multi-test loop
// (`int t; cin >> t; while (t--)`) is intentionally left disabled.
int main()
{
	solve();
	return 0;
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...