제출 #678484

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
678484 | cig32 | The Xana coup (BOI21_xanadu) | C++17
100 / 100
96 ms | 40024 KiB
#include <algorithm>
#include <array>
#include <cstddef>
#include <iostream>
#include <vector>

// The Xana coup (BOI 2021, problem "xanadu").
//
// Reads a tree with n vertices (1-indexed): n, then n - 1 edges, then the
// initial 0/1 state of every vertex. Toggling a vertex flips its own state and
// the state of every neighbour; the program prints the minimum number of
// toggles that leaves every vertex at state 0, or "impossible".
//
// Tree DP rooted at vertex 1, evaluated bottom-up over a BFS order so the call
// stack stays flat even when the tree is one long path (a recursive DFS would
// nest n frames deep).

namespace {

using Cost = long long;

// Sentinel cost of an unreachable DP state. Toggling a vertex twice cancels
// out, so each vertex is toggled at most once and every real answer is <= n,
// far below this value.
constexpr Cost kInf = 1'000'000'000LL;

// Saturating addition: unreachable states stay pinned at kInf instead of
// accumulating ever-larger sums of sentinels.
Cost add(Cost a, Cost b) { return std::min(kInf, a + b); }

// Per-vertex table dp[s][t], counting only toggles inside the vertex's subtree:
//   s - state of the vertex after its own toggle and its children's toggles,
//       i.e. before its parent's toggle (if any) is applied;
//   t - 1 if the vertex itself is toggled, 0 otherwise.
using Table = std::array<std::array<Cost, 2>, 2>;

// Minimum number of toggles that brings every vertex to state 0, or kInf if
// no set of toggles does. `adj` and `state` are indexed 1..n (index 0 unused);
// `state` entries must be 0 or 1. Vertex 1 is the root.
Cost min_toggles(const std::vector<std::vector<int>>& adj,
                 const std::vector<int>& state) {
  const int n = static_cast<int>(adj.size()) - 1;

  // BFS from the root: every parent appears before its children, so reading
  // `order` backwards visits children before parents (a valid post-order).
  std::vector<int> order;
  order.reserve(static_cast<std::size_t>(n));
  std::vector<int> parent(adj.size(), 0);  // 0 = "no parent" (root)
  order.push_back(1);
  for (std::size_t head = 0; head < order.size(); ++head) {
    const int v = order[head];
    for (const int u : adj[v]) {
      if (u != parent[v]) {
        parent[u] = v;
        order.push_back(u);
      }
    }
  }

  std::vector<Table> dp(adj.size());
  for (auto it = order.rbegin(); it != order.rend(); ++it) {
    const int v = *it;

    // acc[t][x]: cheapest total cost of v's children when v's own toggle is t
    // and x is the parity of the number of toggled children. Each child must
    // end at state 0 after v's toggle, so before it the child needs s == t.
    Table acc;
    acc[0] = {0, kInf};
    acc[1] = {0, kInf};
    for (const int c : adj[v]) {
      if (c == parent[v]) continue;
      for (int t = 0; t < 2; ++t) {
        const std::array<Cost, 2>& child = dp[c][t];  // indexed by c's toggle
        const std::array<Cost, 2> prev = acc[t];
        acc[t][0] = std::min(add(prev[0], child[0]), add(prev[1], child[1]));
        acc[t][1] = std::min(add(prev[0], child[1]), add(prev[1], child[0]));
      }
    }

    // v ends at state[v] ^ t ^ parity, so reaching s requires
    // parity == state[v] ^ t ^ s; toggling v itself costs t.
    for (int s = 0; s < 2; ++s) {
      for (int t = 0; t < 2; ++t) {
        dp[v][s][t] = add(acc[t][state[v] ^ t ^ s], t);
      }
    }
  }

  // The root has no parent to flip it afterwards, so it must already be 0.
  return std::min(dp[1][0][0], dp[1][0][1]);
}

}  // namespace

int main() {
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);

  int n = 0;
  if (!(std::cin >> n) || n < 1) return 0;

  std::vector<std::vector<int>> adj(static_cast<std::size_t>(n) + 1);
  for (int i = 1; i < n; ++i) {
    int a = 0;
    int b = 0;
    std::cin >> a >> b;
    adj[a].push_back(b);
    adj[b].push_back(a);
  }

  std::vector<int> state(static_cast<std::size_t>(n) + 1, 0);
  for (int v = 1; v <= n; ++v) {
    int x = 0;
    std::cin >> x;
    state[v] = (x != 0) ? 1 : 0;  // keeps DP indices within {0, 1}
  }

  const Cost best = min_toggles(adj, state);
  if (best >= kInf) {
    std::cout << "impossible" << "\n";
  } else {
    std::cout << best << "\n";
  }
  return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...