Submission #880176

# | Time | Username | Problem | Language | Result | Execution time | Memory
880176 | (not shown) | browntoad | Mergers (JOI19_mergers) | C++14 | 100 / 100 | 1140 ms | 197080 KiB
#include<bits/stdc++.h>
using namespace std;

// Competitive-programming shorthands used by the rest of the file.
// NOTE: `int` is redefined to `long long` for everything below, which is why
// the entry point is declared `signed main`.
#define ll long long
#define int ll
// FOR: i in [a, b); REP: i in [0, n); REP1: i in [1, n].
#define FOR(i, a, b) for (int i = (a); i < (b); i++)
#define REP(i, n) FOR(i, 0, n)
#define REP1(i, n) FOR(i, 1, n+1)
// Reverse loops: RREP: i from n-1 down to 0; RREP1: i from n down to 1.
#define RREP(i, n) for (int i = (n)-1; i >= 0; i--)
#define RREP1(i, n) for (int i = (n); i >= 1; i--)
#define pii pair<int, int>
// Caution: these two macros rewrite ANY identifier named `f` or `s`
// (e.g. a local variable) into `first` / `second`.
#define f first
#define s second
#define pb push_back
#define ALL(x) (x).begin(), (x).end()
#define SZ(x) (int)((x).size())

// Capacity of the global per-vertex / per-colour arrays
// (presumably sized for the problem's 5e5 limits — confirm against statement).
const ll maxn = 5e5+5;
// Not referenced anywhere in the visible code.
const ll inf = (1ll<<60);

// n: number of vertices (main reads n-1 edges after it).
// k: second header value; read in main but not used elsewhere in this file.
int n, k;
// g: undirected tree adjacency lists, 1-indexed. g2 is declared but never used.
vector<int> g[maxn], g2[maxn];
// col[v]: colour of vertex v (read in main).
// par[v]: DSU parent pointer used by fin/merg; initialised to par[v] = v in main.
vector<int> col(maxn), par(maxn);

// DSU find with full path compression.
// Returns the representative (root) of a's set and re-points every vertex on
// the path from a to that root directly at the root.
// Written iteratively: merg() always hangs a vertex's root under its parent,
// so parent chains can grow as long as a root-to-leaf path of the tree, and a
// recursive find could then nest that deep and overflow the call stack.
int fin(int a){
    // pass 1: walk up to the root
    int root = a;
    while (par[root] != root) root = par[root];
    // pass 2: compress the walked path onto the root
    while (par[a] != root) {
        int nxt = par[a];
        par[a] = root;
        a = nxt;
    }
    return root;
}
// DSU union: makes b's representative the parent of a's representative.
// No rank/size heuristic; if a and b already share a root this is a no-op
// (the root is re-assigned to itself).
void merg(int a, int b){
    int ra = fin(a);
    int rb = fin(b);
    par[ra] = rb;
}

// st[v]: colour -> number of vertices of that colour in v's subtree.
// Built small-to-large in dfs: a child's map is either swapped into its parent
// or copied into it and cleared, so only the vertex just finished holds data.
map<int, int> st[maxn];
// mx[c]: total number of vertices of colour c in the whole tree (filled in main).
// crs[v]: number of colours present in v's subtree whose count there is below mx[c].
// sz[v]: number of vertices in v's subtree.
vector<int> mx(maxn), crs(maxn), sz(maxn);
void dfs(int x, int prev){
    sz[x]++;
    pii cid = {-1, -1};
    int u;
    REP(i, SZ(g[x])){
        u = g[x][i];
        if (u == prev) continue;
        dfs(u, x);
        cid = max(cid, {sz[u], u});
        sz[x] += sz[u];
    }

    if (cid.s != -1) {
        st[x].swap(st[cid.s]);
        crs[x] = crs[cid.s];
    }
    REP(i, SZ(g[x])){
        u = g[x][i];
        if (u == prev || u == cid.s) continue;
        for (auto p:st[u]){
            if (st[x][p.f] == 0) crs[x]++;
            st[x][p.f] += p.s;
            if (st[x][p.f] == mx[p.f]) crs[x]--;
        }
        st[u].clear();
    }
    if (st[x][col[x]] == 0) crs[x]++;
    st[x][col[x]]++;
    if (st[x][col[x]] == mx[col[x]]) crs[x]--;

    if (crs[x] > 0) {
        merg(x, prev);
        //cout<<x<<' '<<prev<<endl;
    }


}



signed main(){
    ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
    cin>>n>>k;
    REP1(i, n) par[i] = i;
    REP(i, n-1){
        int u, v; cin>>u>>v;
        g[u].pb(v);
        g[v].pb(u);
    }
    REP1(i, n){
        cin>>col[i];
        mx[col[i]]++;
    }
    dfs(1, -1);
    vector<int> deg(n+1);
    REP1(i, n){
        for (auto u:g[i]) if (fin(u) != fin(i) && u < i) deg[fin(i)]++, deg[fin(u)]++; // so it only counts once
    }
    int lf = 0;
    REP1(i, n) if (deg[i] == 1) lf++;
    cout<<(lf/2)+(lf&1)<<endl;


}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...