Submission #898819

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
| 898819 | | juliany2 | Capital City (JOI20_capital_city) | C++17 | 41 / 100 | 3008 ms | 70228 KiB |
#include<bits/stdc++.h>
using namespace std;
using ll = long long;
#define all(x) (x).begin(), (x).end()

// Bottom-up (iterative) segment tree.
// - Leaves live at t[sz .. 2*sz).
// - Internal node i (1 <= i < sz) stores comb(t[2i], t[2i+1]).
// - Queries are over closed ranges [l, r].
template<class T> struct ST {
    // Neutral element of comb(); this literal fits T = array<int, 2> / pair<int, int>.
    static constexpr T ID = {(int) 1e9, 0}; // or whatever ID
    // Associative combine: lexicographic minimum for array/pair types.
    inline T comb(T a, T b) { return min(a, b); } // or whatever function

    int sz = 0;   // number of leaves
    vector<T> t;  // 2 * sz nodes; t[0] is never read

    // Resets the tree to _sz leaves that all hold `val` (ID by default) and
    // rebuilds the internal nodes from those leaves, so a non-default val is
    // honoured instead of being silently replaced by ID.
    void init(int _sz, T val = ID) {
        t.assign((sz = _sz) * 2, val);
        for (int i = sz - 1; i > 0; --i)
            t[i] = comb(t[i * 2], t[(i * 2) | 1]);
    }
    // Builds the tree over the values of v (v itself is not modified).
    void init(const vector<T> &v) {
        t.resize((sz = (int) v.size()) * 2);
        for (int i = 0; i < sz; ++i)
            t[i + sz] = v[i];
        for (int i = sz - 1; i > 0; --i)
            t[i] = comb(t[i * 2], t[(i * 2) | 1]);
    }
    // Sets position i (0-based) to x and refreshes all of its ancestors.
    void upd(int i, T x) {
        for (t[i += sz] = x; i > 1; i >>= 1)
            t[i >> 1] = comb(t[i], t[i ^ 1]);
    }
    // comb over positions l..r inclusive; returns ID when l > r.
    T query(int l, int r) {
        T ql = ID, qr = ID;
        for (l += sz, r += sz + 1; l < r; l >>= 1, r >>= 1) {
            if (l & 1) ql = comb(ql, t[l++]);
            if (r & 1) qr = comb(t[--r], qr);
        }
        return comb(ql, qr);
    }
};

// Disjoint-set union over elements 0..sz, with union by size and path compression.
struct DSU {
    // e[x] <  0 : x is a root and -e[x] is the size of its set.
    // e[x] >= 0 : e[x] is the parent of x.
    vector<int> e;
    DSU(int sz) : e(sz + 1, -1) {}

    // Returns the root of x's set, re-pointing every node on the way at that root.
    int get(int x) {
        int root = x;
        while (e[root] >= 0)
            root = e[root];
        while (x != root) {
            const int parent = e[x];
            e[x] = root;
            x = parent;
        }
        return root;
    }

    bool same_set(int a, int b) { return get(a) == get(b); }

    // Number of elements in x's set.
    int size(int x) { return -e[get(x)]; }

    // Merges the sets of x and y; returns false if they were already one set.
    bool unite(int x, int y) {
        int rx = get(x);
        int ry = get(y);
        if (rx == ry)
            return false;
        if (e[ry] < e[rx])  // keep the larger set's root as rx (ties keep x's root)
            swap(rx, ry);
        e[rx] += e[ry];
        e[ry] = rx;
        return true;
    }
};

// Capacity of the per-town / per-color arrays, and the number of binary-lifting
// levels (2^L must exceed the tree depth).
const int N = 2e5 + 7;
const int L = 20;

int n, k;              // number of towns, number of colors (cities)

vector<int> adj[N];    // tree adjacency; dfs_hld continues v's chain through adj[v][0]
vector<int> col[N];    // col[color] = towns of that color
vector<int> topo;      // colors in dfs1 finishing order (reversed in main)
vector<int> comp;      // colors collected by the current top-level dfs2 call
set<int> active;       // HLD positions not yet consumed by process()
ST<array<int, 2>> st;  // indexed by tin: {dep[start[c[x]]], x} for each town x

int c[N];              // c[town] = its color
int lift[N][L];        // lift[v][i] = 2^i-th ancestor of v (0 above the root)
int dep[N];            // depth; the root (town 1) stays at 0
int sz[N];             // subtree size
int head[N];           // HLD chain head; main later reuses it as a per-component town marker
int pos[N];            // HLD position of each town
int start[N];          // start[color] = LCA of all towns of that color
int who[N];            // who[position] = color of the town at that HLD position
int timer;             // shared counter for tin (dfs) and pos (dfs_hld)
int tin[N], tout[N];   // v's subtree occupies tin values [tin[v], tout[v]]
int reach[N];          // per chain head: deepest dep[b] whose chain prefix query() processed (-1 = none)

bool vis[N];           // per-color visited flag (dfs1 pass, then reset for dfs2)
bool done[N];          // not referenced in the visible code

// First DFS from town v with parent p (root defaults to town 1, parent 0).
// It fills:
// - tin / tout: v's subtree occupies tin values [tin[v], tout[v]];
// - sz: subtree sizes;
// - dep[u] = dep[v] + 1 for every child u;
// - lift: the binary-lifting table.
// It also reorders adj[v] so that adj[v][0] is the child with the largest
// subtree, which dfs_hld uses as the heavy child. For a leaf, adj[v][0] stays
// the parent.
void dfs(int v = 1, int p = 0) {
    tin[v] = ++timer;
    sz[v] = 1;
    lift[v][0] = p;
    for (int i = 1; i < L; i++)
        lift[v][i] = lift[lift[v][i - 1]][i - 1];

    for (int &u : adj[v]) {
        if (u == p)
            continue;
        dep[u] = dep[v] + 1;
        dfs(u, v);
        sz[v] += sz[u];
        // Keep the largest child seen so far at adj[v][0].
        // The comparison must use the child's own size sz[u]: sz[v] already
        // includes adj[v][0] plus one, so testing sz[v] would always swap. That
        // would make the last child "heavy" and destroy the O(log n)
        // chains-per-path bound of the decomposition.
        // Swapping mid-loop is safe: adj[v][0] is the parent or an
        // already-processed child.
        if (adj[v][0] == p || sz[u] > sz[adj[v][0]])
            swap(u, adj[v][0]);
    }

    tout[v] = timer;
}

// Second DFS: gives towns consecutive positions pos[] (starting from the
// current timer value), visiting adj[v] in order, and sets head[u] for each
// child u. The child stored at adj[v][0] continues v's chain; every other
// child starts a chain of its own.
void dfs_hld(int v = 1, int p = 0) {
    pos[v] = timer++;
    for (int u : adj[v]) {
        if (u == p)
            continue;
        const bool continues_chain = (u == adj[v][0]);
        head[u] = continues_chain ? head[v] : u;
        dfs_hld(u, v);
    }
}

int lca(int u, int v) {
    if (dep[u] > dep[v])
        swap(u, v);

    for (int i = L - 1; ~i; --i)
        if (dep[v] - (1 << i) >= dep[u])
            v = lift[v][i];

    if (u == v)
        return u;

    for (int i = L - 1; ~i; --i)
        if (lift[v][i] != lift[u][i])
            v = lift[v][i], u = lift[u][i];
    return lift[u][0];
}

// Forward declaration: dfs1 -> query -> process -> dfs1 are mutually recursive.
void dfs1(int a);

// Consumes (erases from `active`) every still-active HLD position in [a, b].
// For each consumed position whose color has not been visited yet, it runs
// dfs1 on that color. The search restarts from `a` each step, so positions
// consumed by the nested dfs1 calls are never revisited.
void process(int a, int b) {
    for (;;) {
        auto it = active.lower_bound(a);
        if (it == active.end() || *it > b)
            return;
        const int x = *it;
        active.erase(it);
        if (!vis[who[x]])
            dfs1(who[x]);
    }
}

// Climbs from b towards a along heavy chains, calling process() on the HLD
// position ranges covering the path. A whole chain prefix [head, b] is skipped
// when reach[] shows that a prefix of that chain at least as deep was already
// processed.
// NOTE(review): assumes a is an ancestor of b. dfs1 passes start[color] (the
// LCA of that color's towns) and one of those towns.
void query(int a, int b) {
    while (head[b] != head[a]) {
        const int top = head[b];
        if (reach[top] < dep[b]) {
            process(pos[top], pos[b]);
            reach[top] = dep[b];
        }
        b = lift[top][0];
    }
    process(pos[a], pos[b]);
}

// First pass over colors. It marks color a visited, then, for each town of
// color a, scans the tree path from start[a] to that town (via query/process),
// descending into any unvisited color found there. Finally it appends a to
// topo in post-order.
void dfs1(int a) {
    vis[a] = true;
    const vector<int> &towns = col[a];
    for (size_t i = 0; i < towns.size(); ++i)
        query(start[a], towns[i]);
    topo.push_back(a);
}

// Second pass over colors: marks color a visited, appends it to comp, then
// explores further colors. For each town v of color a, the segment tree
// (keyed by tin) yields towns x inside v's subtree whose color's start has
// depth <= dep[v]. main builds start[c] as the LCA of color c's towns, so v
// then lies on the path start[c[x]] -> x.
// Each such x is overwritten with the identity so it is examined only once,
// and dfs2 recurses into c[x] if it is unvisited.
void dfs2(int a) {
    vis[a] = 1;
    comp.push_back(a);
    for (int v : col[a]) {
        // One segment-tree query per step; the found minimum is reused instead
        // of re-running the identical query just to read its town index.
        for (auto best = st.query(tin[v], tout[v]); best[0] <= dep[v];
             best = st.query(tin[v], tout[v])) {
            int x = best[1];
            st.upd(tin[x], {(int) 1e9, 0});
            if (!vis[c[x]])
                dfs2(c[x]);
        }
    }
}

int main() {
    cin.tie(0)->sync_with_stdio(false);

    cin >> n >> k;

    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;

        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    for (int i = 1; i <= n; i++) {
        cin >> c[i];
        col[c[i]].push_back(i);
    }

    timer = 0;
    dfs();

    for (int i = 1; i <= k; i++) {
        start[i] = col[i][0];
        for (int x : col[i]) {
            start[i] = lca(start[i], x);
        }
    }


    timer = 1;
    head[1] = 1;
    dfs_hld();

    for (int i = 1; i <= n; i++)
        who[pos[i]] = c[i];

    for (int i = 1; i <= n; i++)
        active.insert(i);

    memset(reach, -1, sizeof(reach));

    for (int i = 1; i <= k; i++)
        if (!vis[i])
            dfs1(i);

    reverse(all(topo));

    st.init(n + 1);
    memset(vis, 0, sizeof(vis));
    memset(head, 0, sizeof(head));

    for (int i = 1; i <= n; i++)
        st.upd(tin[i], {dep[start[c[i]]], i});

    DSU dsu(n + 1);

    int ans = k - 1;
    for (int v : topo) {
        if (!vis[v]) {
            comp.clear();
            dfs2(v);

            int cnt = 0;
            for (int x : comp)
                for (int u : col[x])
                    head[u] = v, cnt++;

            for (int x : comp)
                for (int s : col[x])
                    for (int t : adj[s])
                        if (head[t] == v)
                            dsu.unite(s, t);

            if (dsu.size(col[comp[0]][0]) == cnt)
                ans = min(ans, (int) comp.size() - 1);
        }
    }

    cout << ans << '\n';

    return 0;
}

| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...