Submission #400959

# | Time | Username | Problem | Language | Result | Execution time | Memory
400959 | | Hegdahl | The Xana coup (BOI21_xanadu) | C++17 | 100 / 100 | 90 ms | 32560 KiB
#include <bits/stdc++.h>
#define ar array

using namespace std;
using ll = long long;

// Sentinel cost meaning "no valid configuration". Every value returned by
// solve() is clamped to at most INF; main() treats any answer above n as
// infeasible.
const ll INF = 1e9;

// Capacity of the global adjacency list and state array below.
const int mxN = 1e5;
vector<int> g[mxN];  // g[v]: neighbours of node v (0-indexed)
int a[mxN];          // a[v]: initial state of node v; assumed to be 0 or 1

// Tree DP over the subtree rooted at `cur`, entered from `prv` (pass -1 for
// the root).
//
// Returns four costs indexed by the 2-bit mask 0b(pressed)(state):
//   pressed - whether `cur` itself is pressed,
//   state   - the state of `cur` after its own press and its children's
//             presses are applied, but before any press of `prv`.
// Each entry is the minimum number of presses inside the subtree for which
// every child of `cur` ends in state 0 once `cur`'s press is accounted for,
// or INF when no such choice exists.
//
// The work per node is linear in its number of children: instead of sorting
// children by the price of pressing them, a running minimum is kept for each
// parity of "number of pressed children". Recursion depth equals the height
// of the tree.
ar<ll, 4> solve(int cur, int prv) {
  // best[p][q]: cheapest cost over the children seen so far when `cur` has
  // pressed-flag p and the count of pressed children has parity q.
  // With no children the parity is 0 and the only cost is cur's own press,
  // which also covers the leaf case.
  ll best[2][2] = {{0, INF}, {1, INF}};

  for (int nxt : g[cur]) {
    if (nxt == prv) continue;
    const ar<ll, 4> kid = solve(nxt, cur);

    for (int p = 0; p < 2; ++p) {
      // The child must be in state p before cur's press reaches it, so that
      // it ends in state 0 afterwards.
      const ll skip = kid[0b00 | p];  // child not pressed
      const ll take = kid[0b10 | p];  // child pressed: flips the parity
      const ll even = min(best[p][0] + skip, best[p][1] + take);
      const ll odd = min(best[p][1] + skip, best[p][0] + take);
      best[p][0] = even;
      best[p][1] = odd;
    }
  }

  ar<ll, 4> ret {INF, INF, INF, INF};
  for (int p = 0; p < 2; ++p) {
    for (int q = 0; q < 2; ++q) {
      // cur is toggled by its own press (p) and by each pressed child (q).
      const int state = a[cur] ^ p ^ q;
      // Sums that contain an infeasible term exceed INF; clamp them so that
      // "impossible" is always reported as exactly INF.
      ret[(p << 1) | state] = min(INF, best[p][q]);
    }
  }
  return ret;
}

int main() {
  ios::sync_with_stdio(0);cin.tie(0);

  int n; cin >> n;
  for (int nn = 0; nn < n-1; ++nn) {
    int i, j; cin >> i >> j; --i, --j;
    g[i].push_back(j);
    g[j].push_back(i);
  }

  for (int i = 0; i < n; ++i)
    cin >> a[i];

  auto ans_ar = solve(0, -1);
  ll ans = min(ans_ar[0b00], ans_ar[0b10]);
  if (ans > n) {
    cout << "impossible\n";
    return 0;
  }
  cout << ans << '\n';
}

# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...