Submission #216687

# | Time | Username | Problem | Language | Result | Execution time | Memory
216687 | | combi1k1 | Capital City (JOI20_capital_city) | C++14 | 100 / 100 | 398 ms | 46944 KiB
#include<bits/stdc++.h>

using namespace std;

// Short type aliases used by this submission.
#define ll  long long
// NOTE: `ld` expands to plain `double`, not `long double`.
#define ld  double

// Container helpers: size as int, and the begin/end iterator pair for algorithm calls.
// NOTE(review): the argument of sz(x) is not parenthesised; pass a plain name, not an expression.
#define sz(x)   (int)x.size()
#define all(x)  x.begin(),x.end()

// Shorthand for vector::emplace_back and for pair member access.
#define pb  emplace_back
#define X   first
#define Y   second

// Capacity of every per-vertex / per-color array in this file (2e5 + 5).
constexpr int N = 200005;

using ii = std::pair<int, int>;

// Disjoint-set forest over vertices 1..n.
// p[v] = parent link (p[v] == v at a root); s[r] = size of the set rooted at r.
int p[N];
int s[N];

/// Resets vertices 1..n to singleton sets. Always returns 1.
int init(int n) {
    for (int v = 1; v <= n; ++v) {
        p[v] = v;
        s[v] = 1;
    }
    return 1;
}

/// Returns the root of x's set and re-points every vertex on the walked path
/// directly at that root (path compression).
int lead(int x) {
    int root = x;
    while (p[root] != root)
        root = p[root];

    // Second pass: hang each vertex on the path straight off the root.
    while (p[x] != root) {
        const int next = p[x];
        p[x] = root;
        x = next;
    }
    return root;
}

/// Unites the sets containing x and y, attaching the smaller set under the larger
/// (on equal sizes, x's root stays the root).
/// @return 1 if two different sets were merged, 0 if x and y were already together.
int join(int x, int y) {
    const int rx = lead(x);
    const int ry = lead(y);
    if (rx == ry)
        return 0;

    const bool keepX = !(s[rx] < s[ry]);
    const int big = keepX ? rx : ry;
    const int small = keepX ? ry : rx;

    p[small] = big;
    s[big] += s[small];
    return 1;
}
// Input arrays: c[v] = color of town v; (a[i], b[i]) = endpoints of the i-th input edge.
int c[N];
int a[N];
int b[N];

// S[col] = DSU representatives whose color is col; g = adjacency between representatives.
vector<int> S[N];
vector<int> g[N];

// Rooted-tree / HLD data filled by dfs() and hld():
// nCh = subtree size, pos = 1-based pre-order position, led = head of the heavy chain,
// par = parent, arr[pos[x]] = x, tot = number of positions handed out so far.
int nCh[N], pos[N];
int led[N], par[N];
int arr[N], tot = 0;

/*
 * Roots the tree stored in g at u.
 *
 * Sets par[x] for every vertex reachable from u and sets nCh[x] to the number
 * of vertices in x's subtree (x included).
 *
 * Uses an explicit stack rather than recursion. The recursive form needs one
 * call frame per tree level, and a deep tree (e.g. a long path of up to n
 * vertices) can overflow the call stack.
 *
 * @param u  root of the traversal
 * @param p  value stored as par[u]; it is excluded from u's children.
 *           nCh[p] is never modified (hld() relies on nCh[0] staying 0).
 */
void dfs(int u,int p)   {
    par[u] = p;

    // Pre-order: every vertex is appended after its parent.
    std::vector<int> order;
    std::vector<int> todo(1, u);
    while (!todo.empty()) {
        const int x = todo.back();
        todo.pop_back();
        order.push_back(x);
        nCh[x] = 1;
        for (int v : g[x])
            if (v != par[x]) {
                par[v] = x;
                todo.push_back(v);
            }
    }

    // Reverse pre-order finalises each child before it is added to its parent.
    // order[0] is u itself; its "parent" p is deliberately skipped.
    for (std::size_t i = order.size(); i-- > 1; )
        nCh[par[order[i]]] += nCh[order[i]];
}
/*
 * Heavy-light decomposition of the tree rooted by dfs().
 *
 * Effects:
 *  - assigns pre-order positions pos[x], 1-based and counted in the global tot;
 *  - records arr[pos[x]] = x;
 *  - sets led[x] to the head of x's heavy chain.
 *
 * Children are numbered heavy child first (the first child in g order with the
 * largest nCh), then the remaining children in g order. As a result, each heavy
 * chain and each subtree occupy a contiguous range of positions.
 *
 * Uses an explicit stack instead of recursion. The recursive form nests one
 * call frame per tree level, which can overflow the call stack on deep trees.
 *
 * @param u   vertex to start from (main passes the root)
 * @param ok  nonzero if u starts a new chain (led[u] = u); zero to continue the
 *            parent's chain (led[u] = led[par[u]])
 */
void hld(int u,int ok)  {
    // Each entry: (vertex, starts-a-new-chain flag).
    std::vector<std::pair<int, int>> todo;
    todo.emplace_back(u, ok);

    while (!todo.empty()) {
        const int x = todo.back().first;
        const int startsChain = todo.back().second;
        todo.pop_back();

        led[x] = startsChain ? x : led[par[x]];
        pos[x] = ++tot;
        arr[tot] = x;

        // Heavy child: first child with the largest subtree. Relies on nCh[0] == 0
        // (global, never written by dfs) so that 0 means "no child yet".
        int heavy = 0;
        for (int v : g[x])
            if (v != par[x] && nCh[heavy] < nCh[v])
                heavy = v;

        // Light children are pushed in reverse so they pop in g order. The heavy
        // child is pushed last so its whole subtree is numbered first.
        for (auto it = g[x].rbegin(); it != g[x].rend(); ++it)
            if (*it != par[x] && *it != heavy)
                todo.emplace_back(*it, 1);
        if (heavy)
            todo.emplace_back(heavy, 0);
    }
}
bool is_anc(int u,int v)    {   return  pos[u] <= pos[v] && pos[u] + nCh[u] >= pos[v] + nCh[v]; }

int lca(int u,int v)    {
    while (!is_anc(led[u],v))   u = par[led[u]];
    while (!is_anc(led[v],u))   v = par[led[v]];

    if (pos[u] > pos[v])
        swap(u,v);

    assert(is_anc(u,v));

    return  u;
}

int tr[N << 1];
int ok[N];

vector<int> vec[N];

bool have[N];
bool need[N];

int main()  {
    ios_base::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);

    int n;  cin >> n;
    int k;  cin >> k;

    for(int i = 1 ; i < n ; ++i)    {
        cin >> a[i];
        cin >> b[i];
    }
    for(int i = 1 ; i <= n ; ++i)
        cin >> c[i];

    init(n);

    for(int i = 1 ; i < n ; ++i)
        if (c[a[i]] == c[b[i]])
            join(a[i],b[i]);

    for(int i = 1 ; i < n ; ++i)
        if (c[a[i]] != c[b[i]]) {
            int x = lead(a[i]);
            int y = lead(b[i]);

            g[x].pb(y);
            g[y].pb(x);
        }

    for(int i = 1 ; i <= n ; ++i)
        if (i == p[i])
            S[c[i]].pb(i);

    dfs(lead(1),0);
    hld(lead(1),1);

    auto upd = [&](int p,int v) {
        for(tr[p += tot - 1] = v ; p > 1 ; p >>= 1)
            tr[p >> 1] = min(tr[p],tr[p ^ 1]);
    };
    auto get = [&](int l,int r) {
        l += tot - 1;
        r += tot;

        int ans = 1e9;

        for(; l < r ; l >>= 1, r >>= 1) {
            if (l & 1)  ans = min(ans,tr[l++]);
            if (r & 1)  ans = min(ans,tr[--r]);
        }
        return  ans;
    };
    for(int i = 1 ; i <= tot ; ++i)
        upd(i,1e9);

    for(int i = 1 ; i <= k ; ++i)
        sort(all(S[i]),[&](int x,int y) {
            return  pos[x] < pos[y];
        });

    vector<int> perms(k);

    iota(all(perms),1);
    sort(all(perms),[&](int x,int y)    {
        int a = lca(S[x][0],S[x].back());
        int b = lca(S[y][0],S[y].back());

        return  pos[a] < pos[b];
    });

    for(int i : perms)  {
        int R = lca(S[i][0],S[i].back());
        int P = pos[R];

        for(int x : S[i])   {
            while (!is_anc(led[x],R))   {
                P = min(get(pos[led[x]],pos[x]),P);
                x = par[led[x]];
            }
            P = min(get(pos[R],pos[x]),P);
        }
        for(int x : S[i])
            upd(pos[x],P);

        if (P == pos[S[i][0]])
            ok[i] = 1;
    }
    int ans = 1e9;

    for(int i = 1 ; i <= k ; ++i)   if (ok[i])  {
        need[i] = 1;

        queue<int>  qu;
        stack<int>  st;

        for(int x : S[i])
            qu.push(x),
            have[x] = 1;

        int cur = 0;

        while (qu.size())   {
            int u = qu.front();
            qu.pop();
            st.push(u);

            if (get(pos[u],pos[u]) != pos[S[i][0]]) {
                cur = 1e9;
                while (qu.size())   {
                    st.push(qu.front());
                    qu.pop();
                }
                break;
            }

            if (u != S[i][0])
                u = par[u];

            if(!have[u])    {
                have[u] = 1;
                qu.push(u);

                if(!need[c[u]]) {
                    need[c[u]] = 1; ++cur;

                    for(int x : S[c[u]])
                        if(!have[x])    {
                            have[x] = 1;
                            qu.push(x);
                        }
                }
            }
        }
        while (st.size())   {
            int u = st.top();   st.pop();
            have[u] = 0;
            need[c[u]] = 0;
        }
        if (ans > cur)
            ans = cur;
    }
    cout << ans << endl;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...