Submission #256185

# | Time | Username | Problem | Language | Result | Execution time | Memory
256185 | fedoseevtimofey | Mergers (JOI19_mergers) | C++14
100 / 100
1196 ms | 79460 KiB
#include <iostream>
#include <string>
#include <vector>
#include <queue>
#include <deque>
#include <stack>
#include <set>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <algorithm>
#include <random>
#include <iomanip>
#include <functional>
#include <cassert>
 
using namespace std;

// Short aliases for the 64-bit integer types used throughout the file.
using ll = long long;
using ull = unsigned long long;

// Disjoint-set union over vertex ids 0..N-1 with path compression and union
// by size. par[x] is x's parent (x is a root when par[x] == x); ds[x] is the
// component size and is only kept up to date while x is a root.
// Callers must set par[i] = i and ds[i] = 1 before use (main() does this).
const int N = 500007;
int par[N];
int ds[N];

// Returns the root of a's component, re-pointing every vertex on the walked
// path directly at that root.
int get(int a) {
  int root = a;
  while (par[root] != root) {
    root = par[root];
  }
  while (par[a] != root) {
    int next = par[a];
    par[a] = root;
    a = next;
  }
  return root;
}

// Merges the components containing a and b. The root of the smaller
// component is attached under the larger one's root; on a size tie, a's root
// goes under b's root.
void join(int a, int b) {
  int rootA = get(a);
  int rootB = get(b);
  if (rootA == rootB) return;
  if (ds[rootA] > ds[rootB]) swap(rootA, rootB);
  par[rootA] = rootB;
  ds[rootB] += ds[rootA];
}

// Adjacency lists of the input tree (0-indexed vertices).
vector<int> g[N];
// Fixed-seed 64-bit generator; main() draws one random weight per colour.
mt19937_64 rnd(123);

// a[v]: 0-based colour (label) of vertex v; all[c]: number of vertices with
// colour c; sz[v]: subtree size; bigChild[v]: child with the largest subtree
// (-1 if v has no children); [l[v], r[v]]: index range of v's subtree inside
// the pre-order list e.
int a[N], all[N], sz[N], bigChild[N], l[N], r[N];

// go[c]: random weight of colour c; sum[v]: sum of go[a[x]] over every x in
// v's subtree (unsigned arithmetic, so it wraps modulo 2^64).
ull go[N], sum[N];
// Vertices in the order dfs first enters them.
vector<int> e;

// Roots the tree at u (p is u's parent, -1 for the root) and fills e, l, r,
// sz, sum and bigChild for every vertex of u's subtree. On a size tie the
// earliest child in g[u] stays the heavy child. Recursion depth equals the
// depth of the tree.
void dfs(int u, int p) {
  l[u] = (int)e.size();
  e.push_back(u);
  sz[u] = 1;
  sum[u] += go[a[u]];
  int heavy = -1;
  for (int v : g[u]) {
    if (v == p) continue;
    dfs(v, u);
    sz[u] += sz[v];
    sum[u] += sum[v];
    if (heavy == -1 || sz[v] > sz[heavy]) heavy = v;
  }
  bigChild[u] = heavy;
  r[u] = (int)e.size() - 1;
}

// have[c]: whether colour c already occurs in the colour set currently being
// accumulated by jhfs.
bool have[N];
// Sum of go[c] * all[c] over every colour c with have[c] set.
ull need_sum = 0;

// Small-to-large traversal of u's subtree (p is u's parent, -1 for the root).
//
// Once the children are processed, have[] / need_sum describe exactly the
// colours present in u's subtree. sum[u] (computed by dfs) equals need_sum
// precisely when every colour in the subtree has all of its all[c]
// occurrences inside it, up to collisions of the random 64-bit weights go[].
// If they differ, some colour crosses the edge (u, p), so u is merged with p
// in the DSU.
//
// keep == false: the colour set is cleared (need_sum = 0, have[] reset for the
// subtree) before returning. keep == true (heavy child): the set is left in
// place for the parent to extend. Because clearing also zeroes need_sum,
// light children must run and clear BEFORE the heavy child builds its kept
// set; that is why the loops below are ordered as they are.
//
// NOTE(review): recursion depth equals the tree depth (up to n on a path).
// This relies on the environment providing a large enough stack; confirm
// before running outside the original judge.
void jhfs(int u, int p, bool keep) {
  // Light children first; each one clears its colours on return.
  for (auto v : g[u]) {
    if (v != p && v != bigChild[u]) {
      jhfs(v, u, 0);
    }
  } 
  // Heavy child last; its colour set is kept and reused for u.
  if (bigChild[u] != -1) {
    jhfs(bigChild[u], u, 1);
  }

  // Re-add the colours of every light child's subtree, walking its
  // contiguous pre-order range e[l[v]..r[v]].
  for (auto v : g[u]) {
    if (v != p && v != bigChild[u]) {
      for (int i = l[v]; i <= r[v]; ++i) {
        int c = a[e[i]];
        if (!have[c]) {
          have[c] = 1;
          need_sum += go[c] * all[c];
        }
      }
    }
  }
  // Add u's own colour.
  if (!have[a[u]]) {
    have[a[u]] = 1;
    need_sum += go[a[u]] * all[a[u]];
  }
  // At the root this branch is not taken, assuming the input is a connected
  // tree: every vertex lies in the root's subtree, so sum[root] equals
  // need_sum and join(u, -1) is never reached.
  if (sum[u] != need_sum) {
    join(u, p);
  }
  // Hand an empty colour set back to the caller.
  if (!keep) {
    need_sum = 0;
    for (int i = l[u]; i <= r[u]; ++i) {
      int c = a[e[i]];
      have[c] = 0;
    }
  }
}
  
// Number of colours, read in main().
int k;

// Reads the tree (n vertices, n - 1 edges, 1-based ids) and one colour per
// vertex, then contracts every edge crossed by some colour (via jhfs). It
// prints ceil(L / 2), where L is the number of contracted components with
// exactly one edge leaving them; this is 0 if everything collapses into a
// single component.
int main() {
  ios_base::sync_with_stdio(false);
  cin.tie(0);
#ifdef LOCAL
  freopen("input.txt", "r", stdin);
#endif
  int n;
  cin >> n >> k;
  for (int edge = 1; edge < n; ++edge) {
    int u, v;
    cin >> u >> v;
    --u;
    --v;
    g[u].push_back(v);
    g[v].push_back(u);
  }
  for (int v = 0; v < n; ++v) {
    cin >> a[v];
    --a[v];
    ++all[a[v]];
  }
  for (int c = 0; c < k; ++c) go[c] = rnd();
  for (int v = 0; v < n; ++v) {
    par[v] = v;
    ds[v] = 1;
  }
  dfs(0, -1);
  jhfs(0, -1, 0);

  // deg[root]: number of tree edges leaving the component rooted at root.
  vector<int> deg(n);
  for (int v = 0; v < n; ++v) {
    const int ownRoot = get(v);
    for (int w : g[v]) {
      if (ownRoot != get(w)) ++deg[ownRoot];
    }
  }
  const int leaves = (int)count(deg.begin(), deg.end(), 1);
  cout << (leaves + 1) / 2 << '\n';
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...