Submission #623410

#TimeUsernameProblemLanguageResultExecution timeMemory
623410radalMergers (JOI19_mergers)C++17
10 / 100
74 ms40808 KiB
#include <bits/stdc++.h>
#pragma GCC target("sse,sse2,avx2")
#pragma GCC optimize("unroll-loops,O2")
#define rep(i,l,r) for (int i = l; i < r; i++)
#define repr(i,r,l) for (int i = r; i >= l; i--)
#define X first
#define Y second
#define all(x) (x).begin() , (x).end()
#define pb push_back
#define endl '\n'
#define debug(x) cerr << #x << " : " << x << endl;
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int,int> pll;
constexpr int N = 5e5+10,mod = 998244353,inf = 1e9+10,sq = 700;
inline int mkay(int a,int b){
    if (a+b >= mod) return a+b-mod;
    if (a+b < 0) return a+b+mod;
    return a+b;
}
 
inline int poww(int a,int k){
    if (k < 0) return 0;
    int z = 1;
    while (k){
        if (k&1) z = 1ll*z*a%mod;
        a = 1ll*a*a%mod;
        k /= 2;
    } 
    return z; 
}
vector<int> adj[N],col[N];
int par[N][20],h[N],T,tin[N],calc[N],dp[N];
int odd[N]; 

void pre(int v,int p){
    par[v][0] = p;
    tin[v] = T++;
    for (int u : adj[v]){
        if (u == p) continue;
        h[u] = h[v]+1;
        pre(u,v);
    }
}
bool cmp(int u,int v){
    return (tin[u] < tin[v]);
}

int lca(int u,int v){
    if (h[u] < h[v]) swap(u,v);
    repr(i,19,0){
        if ((1 << i) <= h[u]-h[v]) u = par[u][i];
    }
    if (u == v) return u;
    repr(i,19,0){
        if (par[u][i] != par[v][i]){
            v = par[v][i];
            u = par[u][i];
        }
    }
    return par[v][0];
}
void dfs(int v,int p){
    int t[4] = {0,0,0,0};
    for (int u : adj[v]){
        if (u != p){
            dfs(u,v);
            calc[v] += calc[u];
            /*if (v == 1){
                debug(u);
                debug(calc[u]);
                debug(dp[u]);
                debug(odd[u]);
            }*/
            if (odd[u]){
                if (dp[u]-1) t[3]++;
                else t[0]++;
            }
            else if (!calc[u]){
                if (dp[u]) t[2]++;
                else t[1]++;
            }
            dp[v] += dp[u];
        }
    }
 /*   if (v == 1){
        debug(t[0]);
        debug(t[1]);
        debug(t[2]);
        debug(t[3]);
    }*/
    t[2] %= 2;
    if (2*t[3] >= t[1]){
        t[3] -= t[1]/2;
        t[1] %= 2;
    }
    else{
        t[1] -= 2*t[3];
        t[3] = 0;
    }
    if (t[3]){
        t[3]--;
        if (t[1] && t[2]){
            t[1]--;
            t[2]--;
        }
        else if (t[0] && t[1]){
            t[0]--;
            t[1]--;
            dp[v]--;
        }
        else if (t[1]){
            if (t[1] >= 2){
                t[1] -= 2;
            }
            if (t[1] == 1){
                odd[v] = 1;
                return;
            }
            else{
                odd[v] = 1;
                return;
            }
        }
        else if (t[2] && t[0]){
            dp[v]--;
            t[0]--;
            t[2]--;

        }
        else if (t[0]){
            if (t[0] == 1){
                dp[v]--;
                odd[v] = 1;
                return;
            }
            t[0] -= 2;
            dp[v] -= 2;
        }
        else if (t[2]){
            dp[v]--;
            odd[v] = 1;
            t[2]--;
            return;
        }
        else{
            odd[v] = 1;
            return;
        }
    }
    if (t[2]){
        if (t[1] >= 2){
            t[1] -= 2;
            t[2] = 0;
            dp[v]++;
        }
        else if (t[1] == 1){
            if (t[0]){
                t[0]--;
                t[1]--;
                t[2]--;
            }
            else{
                odd[v] = 1;
                dp[v]++;
                return;
            }
        }
        else{
            if (t[0] >= 2){
                t[0] -= 2;
                t[2]--;
                dp[v]--;
            }
            if (t[0] == 1){
                odd[v] = 1;
                return;
            }
            dp[v]++;
            odd[v] = 1;
            return;
        }
    }
    if (t[0] == t[1]){
        return;
    }
    if (t[0] < t[1]){
        t[1] -= t[0];
        dp[v] += (t[1]+1)/2;
        odd[v] = t[1]%2;
        return;
    }
    t[0] -= t[1];
    dp[v] -= t[0]/2;
    odd[v] = t[0]%2;
}
int main(){
    ios :: sync_with_stdio(0); cin.tie(0);
    int n,k;
    cin >> n >> k;
    rep(i,1,n){
        int u,v;
        cin >> u >> v;
        adj[u].pb(v);
        adj[v].pb(u);
    }
    rep(i,1,n+1){
        int c;
        cin >> c;
        col[c].pb(i);
    }
    rep(i,1,k+1)  sort(all(col[i]),cmp);
    pre(1,0);
    rep(j,1,20){
        rep(i,2,n+1)
            par[i][j] = par[par[i][j-1]][j-1];
    }
    rep(i,1,k+1){
        int sz = col[i].size();
        if (sz < 2) continue;
        rep(j,1,sz){
            calc[col[i][j]]++;
            calc[col[i][j-1]]++;
            calc[lca(col[i][j],col[i][j-1])] -= 2;
        }
    }
    dfs(1,0);
    cout << dp[1];
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...