Submission #332211

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
| 332211 | | LifeHappen__ | 수도 (JOI20_capital_city) | C++14 | 100 / 100 | 789 ms | 41708 KiB |
#include <bits/stdc++.h>

using namespace std;

#define ii pair<int, int>
#define fi first
#define se second
#define pb push_back
#define eb emplace_back

const int N = 2e5 + 5;
int n, k;
vector<int> ad[N], in[N];
int col[N];
ii ed[N];
int sz[N], dd[N], cnt, cnt_col[N], pd[N], vis[N];
int vis_col[N];
int res = 1e9;

// Computes sz[x] (subtree size) for every node x in the subtree hanging from u,
// where par is u's parent and nodes with dd[x] != 0 are treated as removed.
// Adds the number of nodes visited to the global cnt.
// Iterative on purpose: a recursive version needs one stack frame per level,
// i.e. up to n frames on a path-shaped tree.
void find_size(int u, int par) {
  // (node, parent) pairs; every parent is listed before its children.
  std::vector<std::pair<int, int>> order;
  order.emplace_back(u, par);
  for (int i = 0; i < (int)order.size(); ++i) {
    const int x = order[i].first;   // copy out: emplace_back may reallocate
    const int p = order[i].second;
    sz[x] = 1;
    for (int v : ad[x]) {
      if (v != p && !dd[v]) order.emplace_back(v, x);
    }
  }
  // Reverse sweep: children come after their parents, so each subtree size is
  // final before it is added to its parent. Index 0 is u itself, whose parent
  // lies outside the subtree and is left untouched.
  for (int i = (int)order.size() - 1; i > 0; --i) {
    sz[order[i].second] += sz[order[i].first];
  }
  cnt += (int)order.size();
}
// Returns a centroid of the component whose sizes were just filled in by
// find_size(u, par). It expects cnt to hold that component's node count.
// Starting at u, it repeatedly steps into the child whose subtree holds more
// than half of the component. When no such child exists, every piece left by
// removing the current node has at most cnt / 2 nodes.
// The former test `sz[v] >= cnt / 2` could overshoot to a non-centroid; for
// example, it picked the leaf of a 3-node path.
int find_big(int u, int par) {
  bool moved = true;
  while (moved) {
    moved = false;
    for (int v : ad[u]) {
      if (v != par && !dd[v] && 2 * sz[v] > cnt) {
        par = u;
        u = v;
        moved = true;
        break;  // restart the scan from the new node
      }
    }
  }
  return u;
}
void cen(int u) {
  cnt = 0;
  find_size(u, 0);
  u = find_big(u, 0);
  //cerr << u << ' ';
  int root = u;
  pd[u] = u;
  function<void(int, int)> dfs = [&](int u, int par) {
    vis[u] = root;
    for (auto v : ad[u]) {
      if(v != par && !dd[v]) {
        pd[v] = u;
        vis[v] = root;
        dfs(v, u);
      }
    }
  };
  dfs(u, 0);
  vector <int> ans;
  queue <int> q;
  q.push(u);
  int dem = -1;
  bool ok = 1;
  while(q.size() && ok) {
    int u = q.front();
    q.pop();
    dem++;
    while(ok) {
      if(vis[u] == 0) break;
      if(vis_col[col[u]] == 0) {
        dem -= cnt_col[col[u]];
        ans.pb(col[u]);
        vis_col[col[u]] = 1;
        for (int &v : in[col[u]]) {
          q.push(v);
          if(vis[v] != root) {
            ok = 0;
            break;
          }
        }
      }
      vis[u] = 0;
      u = pd[u];
      if(u == root) break;
    }
  }
  for (auto v : ans) vis_col[v] = 0;
  if(dem == 0) {
    res = min(res, (int)(ans.size()) - 1);
  }
  dd[u] = 1;
  for (auto v : ad[u]) {
    if(!dd[v]) {
      cen(v);
    }
  }
}

int32_t main() {
  ios_base::sync_with_stdio(false);
  cin.tie(0); cout.tie(0);

  cin >> n >> k;
  for (int i = 1; i < n; ++i) {
    int u, v;
    cin >> u >> v;
    ed[i] = {u, v};
    ad[u].pb(v);
    ad[v].pb(u);
  }
  for (int i = 1; i <= n; ++i) {
    cin >> col[i];
    cnt_col[col[i]]++;
    in[col[i]].pb(i);
  }
  cen(1);
  cout << res << '\n';
}
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...