이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include <bits/stdc++.h>
#define ar array
using namespace std;
typedef long long ll;
// Sentinel cost for an unreachable DP state; it must exceed every real answer
// (real answers are at most n <= mxN).
constexpr ll INF = 1000000000;
// Capacity of the per-node arrays below (maximum number of nodes).
constexpr int mxN = 100000;
// Adjacency lists of the tree (0-based node ids).
vector<int> g[mxN];
// Initial state bit of each node as read from input; assumed to be 0/1 since
// it is XORed into 2-bit DP mask indices.
int a[mxN];
// DP state index layout: 0b(pressed)(state)
//   bit 1 (pressed): whether this node itself is pressed.
//   bit 0 (state):   this node's bit after its own press and its children's
//                    presses, but BEFORE its parent's press is applied.
// An entry equal to INF marks an unreachable state; any other entry is the
// minimum number of presses inside the subtree that realises that state.

// Builds the DP table of `cur` from the DP tables of its children.
// kids[i][s] is the cost for the i-th child's subtree to end in state index s.
// Runs in O(kids.size()).
static ar<ll, 4> combine_kids(int cur, const vector<ar<ll, 4>>& kids) {
    ar<ll, 4> ret {INF, INF, INF, INF};
    for (int pressed = 0; pressed < 2; ++pressed) {
        // Pressing cur toggles every child once more, and each child must end
        // at 0, so a child's state bit before that toggle must equal `pressed`.
        // by_parity[p]: cheapest total where the number of pressed children
        // has parity p (the press on cur itself, if any, is included).
        ll by_parity[2] = {pressed, INF};
        for (const auto& kid : kids) {
            const ll idle = kid[0b00 | pressed];  // child not pressed
            const ll push = kid[0b10 | pressed];  // child pressed
            const ll even = min(by_parity[0] + idle, by_parity[1] + push);
            const ll odd = min(by_parity[0] + push, by_parity[1] + idle);
            by_parity[0] = even;
            by_parity[1] = odd;
        }
        for (int p = 0; p < 2; ++p) {
            // cur's bit is flipped by its own press and by every pressed child.
            const int state = a[cur] ^ pressed ^ p;
            const int msk = (pressed << 1) | state;
            // Capping at INF keeps every stored value <= INF, so sums over
            // children cannot overflow `ll`.
            ret[msk] = min(ret[msk], by_parity[p]);
        }
    }
    return ret;
}

// Returns the DP table (indexed as described above) for the subtree rooted at
// `cur`, where `prv` is cur's parent (or -1 for the root).
// The traversal is iterative: a path-shaped tree with ~mxN nodes would need
// ~mxN nested recursive calls, which can overflow the call stack.
ar<ll, 4> solve(int cur, int prv) {
    vector<int> parent(mxN, -1);
    vector<int> order;  // preorder: every node appears after its parent
    vector<int> todo {cur};
    parent[cur] = prv;
    while (!todo.empty()) {
        const int v = todo.back();
        todo.pop_back();
        order.push_back(v);
        for (int w : g[v]) {
            if (w == parent[v]) continue;
            parent[w] = v;
            todo.push_back(w);
        }
    }
    vector<ar<ll, 4>> dp(mxN);
    vector<ar<ll, 4>> kids;
    // Reverse preorder finishes every child before its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int v = *it;
        kids.clear();
        for (int w : g[v])
            if (w != parent[v]) kids.push_back(dp[w]);
        dp[v] = combine_kids(v, kids);
    }
    return dp[cur];
}
int main() {
ios::sync_with_stdio(0);cin.tie(0);
int n; cin >> n;
for (int nn = 0; nn < n-1; ++nn) {
int i, j; cin >> i >> j; --i, --j;
g[i].push_back(j);
g[j].push_back(i);
}
for (int i = 0; i < n; ++i)
cin >> a[i];
auto ans_ar = solve(0, -1);
ll ans = min(ans_ar[0b00], ans_ar[0b10]);
if (ans > n) {
cout << "impossible\n";
return 0;
}
cout << ans << '\n';
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |