# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
---|---|---|---|---|---|---|---|
755758 | gnu | The Xana coup (BOI21_xanadu) | C++14 | 0 ms | 0 KiB |
이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include <iostream>
#include <vector>
using namespace std;
using ll = long long;
vector<vector<vector<ll>>> dp;
vector<int> state;
void dfs(int u, int p, vector<vector<int>>& gp)
{
ll sum1 = 0, cnt1 = 0, mn1 = 1e18, sum2 = 0, cnt2 = 0, mn2 = 1e18;
for (int v : gp[u]) {
if (v == p) continue;
dfs(v, u, gp);
if (dp[v][1][0] < dp[v][0][0]) {
sum1 += dp[v][1][0];
++cnt1;
}
else {
sum1 += dp[v][0][0];
}
mn1 = min(mn1, abs(dp[v][1][0] - dp[v][0][0]));
if (dp[v][1][1] < dp[v][0][1]) {
sum2 += dp[v][1][1];
++cnt2;
}
else {
sum2 += dp[v][0][1];
}
mn2 = min(mn2, abs(dp[v][1][1] - dp[v][0][1]));
}
if (gp[u].size() == 1 && u != 0) {
if (state[u]) {
dp[u][1][0] = 1;
dp[u][0][0] = 1e18;
dp[u][1][1] = 1e18;
dp[u][0][1] = 0;
}
else {
dp[u][1][0] = 1e18;
dp[u][0][0] = 0;
dp[u][1][1] = 1;
dp[u][0][1] = 1e18;
}
return;
}
if (((cnt1 & 1) && state[u]) || (!(cnt1 & 1) && !state[u])) {
dp[u][0][0] = sum1;
}
else {
dp[u][0][0] = sum1 + mn1;
}
state[u] = !state[u];
if (((cnt2 & 1) && state[u]) || (!(cnt2 & 1) && !state[u])) {
dp[u][1][0] = sum2+1;
}
else {
dp[u][1][0] = sum2 + mn2+1;
}
if (((cnt1 & 1) && state[u]) || (!(cnt1 & 1) && !state[u])) {
dp[u][0][1] = sum1;
}
else {
dp[u][0][1] = sum1 + mn1;
}
state[u] = !state[u];
if (((cnt2 & 1) && state[u]) || (!(cnt2 & 1) && !state[u])) {
dp[u][1][1] = sum2+1;
}
else {
dp[u][1][1] = sum2 + mn2+1;
}
}
void solve()
{
int n; cin >> n;
vector<vector<int>> gp(n);
state = vector<int>(n);
dp = vector<vector<vector<ll>>>(n, vector<vector<ll>>(2, vector<ll>(2)));
for (int i = 0; i < n - 1; ++i) {
int u, v; cin >> u >> v; --u; --v;
gp[u].push_back(v);
gp[v].push_back(u);
}
for (int& x : state) {cin >> x;x=!x;}
dfs(0, -1, gp);
if (min(dp[0][0][0], dp[0][1][0]) <= n) {
cout << min(dp[0][0][0], dp[0][1][0]) << '\n';
}
else {
cout <<"impossible\ n";
}
}
int main()
{
//int t; cin >> t; while (t--)
solve();
}