// Submission #740832
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 740832 | (not captured) | josanneo22 | Capital City (JOI20_capital_city) | C++17 | 100 / 100 | 249 ms | 52824 KiB
#include<bits/stdc++.h> 
using namespace std; 
typedef long long ll;

// Capacity of every per-vertex and per-colour table below.
const int N = 200000 + 7;

// Tree adjacency lists, filled from the input edges.
vector<int> g[N];
// Colour dependency graph (gr) and its reverse (rg); edges enter only via add().
vector<int> gr[N];
vector<int> rg[N];
// Colours in the order get_topsort() finishes them.
vector<int> t;

int c[N];      // c[v]: 0-based colour of vertex v
int sz[N];     // sz[id]: number of colours in strongly connected component id
int cmp[N];    // cmp[colour]: id of the strongly connected component holding it
int bad[N];    // bad[colour]: a vertex of that colour that is an ancestor of all of them, or -1
bool used[N];  // visited flags for the DFS passes over the colour graph
bool ok[N];    // ok[id]: component id has no edge into a different component

// Inserts the directed colour edge u -> v and mirrors it into the reverse graph.
void add(int u, int v) {
    rg[v].push_back(u);
    gr[u].push_back(v);
}

// Euler-tour entry/exit stamps written by init(), and the counter issuing them.
int tin[N];
int tout[N];
int timer = 0;
 
// Euler-tour numbering of the tree g rooted at u.
// tin[u] is stamped when u is entered and tout[u] once its whole subtree has
// been numbered; both come from the shared counter `timer`.
// p is the vertex we arrived from (-1 at the root) so the walk never goes back up.
void init(int u, int p) {
  tin[u] = timer;
  ++timer;
  for (const int next : g[u]) {
    if (next == p) continue;
    init(next, u);
  }
  tout[u] = timer;
  ++timer;
}
 
// True when a is an ancestor of b (a == b counts): b's Euler-tour interval
// [tin, tout] must lie inside a's.
bool anc(int a, int b) {
  if (tin[a] > tin[b]) return false;
  return tout[a] >= tout[b];
}
 
// Walks the tree from u and, for every vertex other than bad[c[u]] (the
// recorded top vertex of its colour), adds the colour edge c[u] -> c[p].
// p is u's parent; it is -1 only for the root, and the root must therefore be
// bad[c[root]] or c[-1] would be read.
// NOTE(review): only parent-colour edges are produced. When a colour's path
// to its LCA passes through the top vertex of another colour, the colours
// above that top vertex are never required, so the colour graph can miss
// dependencies (e.g. edges 1-2 2-3 3-4 4-5 5-6 1-7 7-8 with colours
// 3 4 2 1 2 4 1 3 make the sink-SCC answer 2 instead of 3).
void dfs(int u, int p = -1) {
  const bool is_top = (u == bad[c[u]]);
  if (!is_top) add(c[u], c[p]);
  for (const int child : g[u]) {
    if (child == p) continue;
    dfs(child, u);
  }
}
 
// First Kosaraju pass over the colour graph gr: marks u visited, explores
// every still-unvisited successor, then appends u to t, so t ends up holding
// colours in DFS finishing order.
void get_topsort(int u) {
  used[u] = true;
  for (const int next : gr[u]) {
    if (used[next]) continue;
    get_topsort(next);
  }
  t.push_back(u);
}
 
// Number of strongly connected components found so far; also the id assigned
// to the component jhfs is currently collecting.
int cnt = 0;

// Second Kosaraju pass: flood-fills the reversed colour graph rg from u,
// tagging each newly reached colour with component id cnt and counting it in
// sz[cnt].
void jhfs(int u) {
  used[u] = true;
  cmp[u] = cnt;
  sz[cnt] += 1;
  for (const int prev : rg[u]) {
    if (!used[prev]) jhfs(prev);
  }
}
 
// Capital City (JOI20_capital_city).
// Input: n k, then n-1 tree edges, then the colour of each of the n vertices
// (vertices and colours 1-based). Output: (size of the smallest colour
// closure) - 1, where a closure is a colour set S such that every vertex on a
// path between two vertices of a colour in S also has a colour in S.
//
// A colour graph with only the edges c[u] -> c[parent(u)] is not enough here:
// when a colour's path to its LCA crosses the top vertex of another colour,
// the colours above that vertex are never forced (edges 1-2 2-3 3-4 4-5 5-6
// 1-7 7-8 with colours 3 4 2 1 2 4 1 3 need all 4 colours, answer 3).
//
// Centroid decomposition instead: an optimal closure spans a connected vertex
// set, so it contains a first centroid r (in decomposition order) and lies in
// r's component. Starting from colour c[r] we absorb every colour met on the
// path from an absorbed colour's vertex up to r. If some absorbed colour has a
// vertex outside the component, that closure is handled by an earlier
// centroid; otherwise its size is a candidate. Fully iterative, O(n log n).
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    int k = 0;
    if (!(std::cin >> n >> k) || n <= 0 || k <= 0) return 0;

    std::vector<std::vector<int>> adj(n);
    for (int i = 0; i + 1 < n; ++i) {
        int u = 0;
        int v = 0;
        std::cin >> u >> v;
        --u;
        --v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    std::vector<int> color(n);
    std::vector<std::vector<int>> towns_of(k);  // vertices of each colour
    for (int i = 0; i < n; ++i) {
        std::cin >> color[i];
        --color[i];
        towns_of[color[i]].push_back(i);
    }

    std::vector<char> removed(n, 0);  // vertex already used as a centroid
    std::vector<int> parent(n, -1);   // parent towards the current BFS root
    std::vector<int> subtree(n, 0);   // subtree sizes for the centroid search
    std::vector<int> order;           // BFS order of the current component
    order.reserve(n);

    // Collects the not-yet-removed component containing `root` into `order`
    // (BFS order, parents before children) with parent[] pointing to `root`.
    auto collect = [&](int root) {
        order.clear();
        parent[root] = -1;
        order.push_back(root);
        for (std::size_t head = 0; head < order.size(); ++head) {
            const int u = order[head];
            for (const int w : adj[u]) {
                if (!removed[w] && w != parent[u]) {
                    parent[w] = u;
                    order.push_back(w);
                }
            }
        }
    };

    // Returns a vertex of the component containing `root` whose removal leaves
    // no piece larger than half of the component.
    auto find_centroid = [&](int root) {
        collect(root);
        const int total = static_cast<int>(order.size());
        // Reverse BFS order visits children before their parent.
        for (int i = total - 1; i >= 0; --i) {
            const int u = order[i];
            subtree[u] = 1;
            for (const int w : adj[u]) {
                if (!removed[w] && w != parent[u]) subtree[u] += subtree[w];
            }
        }
        for (const int u : order) {
            int largest = total - subtree[u];
            for (const int w : adj[u]) {
                if (!removed[w] && w != parent[u]) largest = std::max(largest, subtree[w]);
            }
            if (2 * largest <= total) return u;
        }
        return root;  // not reached: every tree has a centroid
    };

    // Per-centroid marks compared against `stamp`, so they never need clearing.
    std::vector<int> in_component(n, -1);
    std::vector<int> vertex_seen(n, -1);
    std::vector<int> color_seen(k, -1);
    std::vector<int> closure;  // colours absorbed for the current centroid
    closure.reserve(k);

    int best = k;  // upper bound: merging every colour
    std::vector<int> pending{0};
    for (int stamp = 0; !pending.empty(); ++stamp) {
        const int centre = find_centroid(pending.back());
        pending.pop_back();
        collect(centre);  // re-root so parent[] leads towards the centroid
        for (const int u : order) in_component[u] = stamp;

        closure.assign(1, color[centre]);
        color_seen[color[centre]] = stamp;
        vertex_seen[centre] = stamp;  // climbs stop at the centroid at the latest
        bool contained = true;
        for (std::size_t head = 0; contained && head < closure.size(); ++head) {
            for (const int town : towns_of[closure[head]]) {
                if (in_component[town] != stamp) {
                    // Closure leaves this component; an earlier centroid covers it.
                    contained = false;
                    break;
                }
                // Climb towards the centroid, absorbing the colour of every new vertex.
                for (int x = town; vertex_seen[x] != stamp; x = parent[x]) {
                    vertex_seen[x] = stamp;
                    if (color_seen[color[x]] != stamp) {
                        color_seen[color[x]] = stamp;
                        closure.push_back(color[x]);
                    }
                }
            }
        }
        if (contained) best = std::min(best, static_cast<int>(closure.size()));

        removed[centre] = 1;
        for (const int w : adj[centre]) {
            if (!removed[w]) pending.push_back(w);
        }
    }
    std::cout << best - 1 << '\n';
}
// Grader results (four tables, columns: # | Verdict | Execution time | Memory | Grader output)
// were not captured; each table only showed "Fetching results...".