This submission was migrated from the previous version of oj.uz, which graded on a different machine. It may receive a different result if resubmitted.
#include <iostream>
#include <string>
#include <vector>
#include <queue>
#include <deque>
#include <stack>
#include <set>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <algorithm>
#include <random>
#include <iomanip>
#include <functional>
#include <cassert>
using namespace std;
using ll = long long;
using ull = unsigned long long;
// Capacity of every per-vertex / per-colour array in this file.
const int N = 500007;
// Disjoint-set forest over vertices: par[x] is x's parent link, ds[root] is the
// number of vertices in the set whose representative is root.
int par[N];
int ds[N];
// Returns the representative of x's set and re-points every vertex on the walked
// path directly at that representative (path compression).
int get(int x) {
    int root = x;
    while (par[root] != root) {
        root = par[root];
    }
    while (x != root) {
        const int up = par[x];
        par[x] = root;
        x = up;
    }
    return root;
}
// Merges the sets containing a and b (no-op if already merged). The smaller set is
// hung under the larger one; on a size tie, a's root goes under b's root.
void join(int a, int b) {
    int ra = get(a);
    int rb = get(b);
    if (ra == rb) {
        return;
    }
    if (ds[ra] > ds[rb]) {
        swap(ra, rb);
    }
    par[ra] = rb;
    ds[rb] += ds[ra];
}
vector <int> g[N];
mt19937_64 rnd(123);
int a[N], all[N], sz[N], bigChild[N], l[N], r[N];
ull go[N], sum[N];
vector <int> e;
// Pre-order traversal of the subtree of u, where p is u's parent (-1 for the root).
// Fills, for every vertex x reached:
//   e         - vertices appended in pre-order;
//   l[x],r[x] - inclusive index range of e covering exactly x's subtree;
//   sum[x]    - sum of go[a[y]] over all y in x's subtree (added onto sum[x]);
//   sz[x]     - subtree size;
//   bigChild[x] - first child (in g[x] order) with maximum subtree size, -1 for a leaf.
// An explicit frame stack replaces recursion so that a path-shaped tree with ~N
// vertices cannot overflow the call stack; visiting order and results are the same
// as the recursive formulation.
void dfs(int u, int p) {
    struct Frame {
        int node;
        int parent;
        size_t next;  // index in g[node] of the next neighbour to examine
    };
    vector<Frame> frames;
    // Pre-order visit: record x's Euler position and seed its aggregates.
    auto enter = [&](int x, int px) {
        e.push_back(x);
        l[x] = (int)e.size() - 1;
        sum[x] += go[a[x]];
        sz[x] = 1;
        bigChild[x] = -1;
        frames.push_back({x, px, 0});
    };
    enter(u, p);
    while (!frames.empty()) {
        const int x = frames.back().node;
        const int px = frames.back().parent;
        if (frames.back().next < g[x].size()) {
            const int v = g[x][frames.back().next++];
            if (v != px) {
                enter(v, x);
            }
            continue;
        }
        // Every child of x is finished: close its Euler interval, then fold it into
        // its parent frame (the top-level u has no parent frame and is not folded).
        r[x] = (int)e.size() - 1;
        frames.pop_back();
        if (!frames.empty()) {
            const int y = frames.back().node;
            sum[y] += sum[x];
            sz[y] += sz[x];
            // Strict '>' keeps the earliest child on ties, as before.
            if (bigChild[y] == -1 || sz[x] > sz[bigChild[y]]) {
                bigChild[y] = x;
            }
        }
    }
}
bool have[N];
ull need_sum = 0;
void jhfs(int u, int p, bool keep) {
for (auto v : g[u]) {
if (v != p && v != bigChild[u]) {
jhfs(v, u, 0);
}
}
if (bigChild[u] != -1) {
jhfs(bigChild[u], u, 1);
}
for (auto v : g[u]) {
if (v != p && v != bigChild[u]) {
for (int i = l[v]; i <= r[v]; ++i) {
int c = a[e[i]];
if (!have[c]) {
have[c] = 1;
need_sum += go[c] * all[c];
}
}
}
}
if (!have[a[u]]) {
have[a[u]] = 1;
need_sum += go[a[u]] * all[a[u]];
}
if (sum[u] != need_sum) {
join(u, p);
}
if (!keep) {
need_sum = 0;
for (int i = l[u]; i <= r[u]; ++i) {
int c = a[e[i]];
have[c] = 0;
}
}
}
int k;
// Reads n, k, the n - 1 tree edges and one colour per vertex (all 1-based), builds
// the merged components via dfs + jhfs, and prints (L + 1) / 2, where L is the number
// of components with exactly one edge to another component (0 if only one remains).
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
#ifdef LOCAL
    freopen("input.txt", "r", stdin);
#endif
    int n;
    cin >> n >> k;
    // n - 1 undirected edges; convert endpoints to 0-based.
    for (int edge = 1; edge < n; ++edge) {
        int x, y;
        cin >> x >> y;
        --x;
        --y;
        g[x].push_back(y);
        g[y].push_back(x);
    }
    // Colour of each vertex (converted to 0-based); all[c] tallies colour c.
    for (int v = 0; v < n; ++v) {
        cin >> a[v];
        --a[v];
        ++all[a[v]];
    }
    // One random 64-bit weight per colour, drawn in colour order.
    for (int c = 0; c < k; ++c) {
        go[c] = rnd();
    }
    // Every vertex starts as a singleton set.
    for (int v = 0; v < n; ++v) {
        par[v] = v;
        ds[v] = 1;
    }
    dfs(0, -1);
    jhfs(0, -1, 0);
    // deg[root] = number of tree edges leaving the component represented by root.
    vector<int> deg(n);
    for (int v = 0; v < n; ++v) {
        const int rv = get(v);
        for (int w : g[v]) {
            if (get(w) != rv) {
                ++deg[rv];
            }
        }
    }
    const int leaves = (int)count(deg.begin(), deg.end(), 1);
    cout << (leaves + 1) / 2 << '\n';
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |