Submission #399130

# | Time | Username | Problem | Language | Result | Execution time | Memory
399130 | Jarif_Rahman | The Xana coup (BOI21_xanadu) | C++17
100 / 100
110 ms | 28648 KiB
#include <bits/stdc++.h>
using namespace std;

// Short-hand macros used throughout the solution.
#define pb push_back
#define f first
#define sc second

using ll = long long int;
using str = string;

// Sentinel cost meaning "no valid choice". dfs clamps every DP entry to at
// most this value, and main prints "impossible" when the best entry reaches it.
const ll inf = 1'000'000'000;

// Number of vertices, read first in main.
int n;

// v[u]: adjacency list of vertex u (0-based), filled from the undirected edges.
vector<vector<int>> v;
// dp[u][p][s]: DP table filled by dfs. p is the toggle flag of u's parent and
// s is u's own toggle flag (s == 1 contributes one press to the count).
vector<vector<vector<ll>>> dp;
// state[u]: initial value read for vertex u; dfs only tests it for non-zero.
vector<int> state;

void dfs(int nd, int ss){
    if(v[nd].size() == 1 && ss != -1){
        dp[nd][0][0] = (state[nd]?inf:0);
        dp[nd][0][1] = (state[nd]?1:inf);
        dp[nd][1][0] = (state[nd]?0:inf);
        dp[nd][1][1] = (state[nd]?inf:1);
        return;
    }
    for(int x: v[nd]) if(x != ss) dfs(x, nd);

    dp[nd][0][1] = 1;
    dp[nd][1][1] = 1;

    int cnt = 0;
    ll mn = inf;
    for(int x: v[nd]) if(x != ss){
        dp[nd][0][0] += min(dp[x][0][0], dp[x][0][1]);
        dp[nd][1][0] += min(dp[x][0][0], dp[x][0][1]);
        if(dp[x][0][0] > dp[x][0][1]) cnt++;
        mn = min(mn, abs(dp[x][0][0]-dp[x][0][1]));
    }
    if(state[nd]){
        if(cnt%2 == 0) dp[nd][0][0]+=mn;
        else dp[nd][1][0]+=mn;
    }
    else{
        if(cnt%2 == 0) dp[nd][1][0]+=mn;
        else dp[nd][0][0]+=mn;
    }

    cnt = 0;
    mn = inf;
    for(int x: v[nd]) if(x != ss){
        dp[nd][0][1] += min(dp[x][1][0], dp[x][1][1]);
        dp[nd][1][1] += min(dp[x][1][0], dp[x][1][1]);
        if(dp[x][1][0] > dp[x][1][1]) cnt++;
        mn = min(mn, abs(dp[x][1][0]-dp[x][1][1]));
    }
    if(!state[nd]){
        if(cnt%2 == 0) dp[nd][0][1]+=mn;
        else dp[nd][1][1]+=mn;
    }
    else{
        if(cnt%2 == 0) dp[nd][1][1]+=mn;
        else dp[nd][0][1]+=mn;
    }
    dp[nd][0][0] = min(dp[nd][0][0], inf);
    dp[nd][1][0] = min(dp[nd][1][0], inf);
    dp[nd][0][1] = min(dp[nd][0][1], inf);
    dp[nd][1][1] = min(dp[nd][1][1], inf);
}

int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cin >> n;
    v.resize(n);
    state.resize(n);
    dp.assign(n, vector<vector<ll>>(2, vector<ll>(2, 0)));
    for(int i = 0; i < n-1; i++){
        int a, b; cin >> a >> b; a--, b--;
        v[a].pb(b);
        v[b].pb(a);
    }
    for(int &x: state) cin >> x;
    dfs(0, -1);
    ll ans = min(dp[0][0][0], dp[0][0][1]);
    if(ans >= inf) cout << "impossible\n";
    else cout << ans << "\n";
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...