#include <iostream>
#include <vector>
#include <algorithm>
#include <numeric>
using namespace std;
// ---- Input ----
int N, K;                       // number of vertices, number of group labels
vector<int> group_of;           // group_of[x]: 0-based group label of vertex x
vector<vector<int>> members;    // members[g]: vertices whose label is g

// ---- Tree rooted at vertex 0 ----
vector<vector<int>> adj;        // adjacency lists; dfs() removes each vertex's parent entry
vector<int> dist;               // dist[x]: depth of x (dist[child] = dist[parent] + 1)

// ---- Euler tour + range-minimum segment tree used for LCA queries ----
vector<pair<int,int>> lcatree;  // segment tree of (depth, vertex) pairs
vector<int> first;              // first[x]: first Euler-tour position of x
int lcatree_i{0};               // next free Euler-tour position
constexpr int lcatree_shift = 1 << 19;  // baseline leaf offset of lcatree

// ---- Contraction state ----
vector<int> mergeuntil;         // per vertex: LCA of its group once main() fills it (initially itself)
vector<bool> penible;           // marks vertices i >= 1 that stay union-find roots after merging
vector<int> uf;                 // union-find parent array
// Returns the representative of u's set in the union-find forest `uf` and
// re-points every vertex on the walked path directly at it (path compression).
// Iterative on purpose: merge() links root to root without balancing, so a
// parent chain can be as long as a whole component. A recursive find could
// therefore recurse up to N deep and overflow the call stack.
int root(int u) {
    int rep = u;
    while (uf[rep] != rep) rep = uf[rep];
    // Second pass: compress the path so later finds are O(1).
    while (uf[u] != rep) {
        const int next = uf[u];
        uf[u] = rep;
        u = next;
    }
    return rep;
}
// Unions the sets containing u and v. The representative of the endpoint with
// the larger dist[] is attached under the representative of the other one;
// on a tie, v's representative goes under u's.
void merge(int u, int v) {
    const bool u_is_deeper = dist[u] > dist[v];
    const int upper = u_is_deeper ? v : u;
    const int lower = u_is_deeper ? u : v;
    const int upper_rep = root(upper);
    uf[root(lower)] = upper_rep;
}
// Depth-first traversal from u. For every vertex v reached from x, it:
//   * removes x from adj[v], so afterwards adj[v] lists only v's children,
//   * sets dist[v] = dist[x] + 1,
//   * records the Euler tour as (depth, vertex) pairs in the leaves of
//     lcatree, starting at index lcatree.size()/2 + lcatree_i. A vertex is
//     recorded on entry and again after each child finishes; first[v] holds
//     its entry position.
// An explicit stack replaces recursion, so path-shaped trees with depth ~N
// cannot overflow the call stack. Visit order and writes match the recursive
// formulation exactly.
void dfs(int u) {
    // Leaves occupy the upper half of the segment-tree array.
    const size_t base = lcatree.size() / 2;
    // Appends x at the next Euler-tour slot.
    auto record = [&](int x) {
        lcatree[base + lcatree_i++] = make_pair(dist[x], x);
    };
    first[u] = lcatree_i;
    record(u);
    // Each frame: a vertex and the index of its next child to visit.
    vector<pair<int, size_t>> stk;
    stk.emplace_back(u, 0);
    while (!stk.empty()) {
        const int x = stk.back().first;
        if (stk.back().second < adj[x].size()) {
            const int v = adj[x][stk.back().second++];
            adj[v].erase(find(adj[v].begin(), adj[v].end(), x));
            dist[v] = dist[x] + 1;
            first[v] = lcatree_i;
            record(v);
            stk.emplace_back(v, 0);
        } else {
            stk.pop_back();
            // Returned to the parent after finishing x: record the parent again.
            if (!stk.empty()) record(stk.back().first);
        }
    }
}
// Post-order pass over the rooted tree below u (adj[] must already hold only
// children). For each vertex x it finds the candidate with the smallest dist[]
// among mergeuntil[y] for all y in x's subtree. The candidates are tried in
// order: x itself first, then each child's result, and ties keep the earlier
// one. Whenever child c's result is not c itself, c is merged with its parent.
// Returns the result for u.
// Iterative, so deep trees cannot overflow the call stack. The merge() calls
// happen in the same order as in the recursive formulation.
int dfs2(int u) {
    // best[x]: current candidate for x's subtree.
    vector<int> best(adj.size());
    // Each frame: a vertex and the index of its next child to visit.
    vector<pair<int, size_t>> stk;
    best[u] = mergeuntil[u];
    stk.emplace_back(u, 0);
    for (;;) {
        const int x = stk.back().first;
        if (stk.back().second < adj[x].size()) {
            const int c = adj[x][stk.back().second++];
            best[c] = mergeuntil[c];
            stk.emplace_back(c, 0);
            continue;
        }
        // x's subtree is finished; fold its result into the parent.
        stk.pop_back();
        if (stk.empty()) return best[x];
        const int p = stk.back().first;
        const int r = best[x];
        if (dist[r] < dist[best[p]]) best[p] = r;
        // The subtree's target is not x itself, so x joins its parent's set.
        if (r != x) merge(x, p);
    }
}
int lca(int u, int v) {
int l = lcatree_shift+first[u], r = lcatree_shift+first[v]+1;
pair<int,int> ret(1e9, 1e9);
for (; l < r; l >>= 1, r >>= 1) {
if (l&1) {
ret = min(ret, lcatree[l++]);
}
if (r&1) {
ret = min(ret, lcatree[--r]);
}
}
return ret.second;
}
// Returns how many vertices in u's subtree are marked in penible[] while
// having no marked proper descendant. The result is at least 1 whenever u
// itself is marked. adj[] must already hold only children.
// Iterative post-order, so deep trees cannot overflow the call stack.
int dfs3(int u) {
    // below[x]: sum of the results of x's already-finished children.
    vector<int> below(adj.size(), 0);
    // Each frame: a vertex and the index of its next child to visit.
    vector<pair<int, size_t>> stk;
    stk.emplace_back(u, 0);
    for (;;) {
        const int x = stk.back().first;
        if (stk.back().second < adj[x].size()) {
            const int c = adj[x][stk.back().second++];
            stk.emplace_back(c, 0);
            continue;
        }
        stk.pop_back();
        const int count = max(below[x], static_cast<int>(penible[x]));
        if (stk.empty()) return count;
        below[stk.back().first] += count;
    }
}
// Entry point. Reads N, K, the N-1 tree edges (1-based) and a 1-based group
// label per vertex. Every vertex is merged up to the LCA of its group; dfs2
// does this by contracting tree edges. The program then counts the leaf
// components of the contracted tree and prints ceil(count / 2).
signed main() {
    ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
    cin >> N >> K;
    dist.resize(N);
    adj.resize(N);
    first.resize(N);
    group_of.resize(N);
    members.resize(K);
    // The Euler tour built by dfs() has 2N-1 entries and must fit in the leaf
    // range of the segment tree. A fixed 1<<19 offset overflows once 2N-1
    // exceeds it, so size the leaf range from N, using lcatree_shift as a floor.
    // (The bottom-up tree works for any leaf count, not only powers of two.)
    const int leaf_base = max(lcatree_shift, 2 * N - 1);
    lcatree.resize(2 * static_cast<size_t>(leaf_base));
    uf.resize(N);
    mergeuntil.resize(N);
    penible.resize(N);
    iota(mergeuntil.begin(), mergeuntil.end(), 0);
    iota(uf.begin(), uf.end(), 0);
    int u, v;
    for (int i = 0; i < N - 1; i++) {
        cin >> u >> v, u--, v--;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for (int i = 0; i < N; i++) {
        cin >> group_of[i], group_of[i]--;
        members[group_of[i]].push_back(i);
    }
    // Root the tree at 0, record the Euler tour, then build the min-tree bottom-up.
    dfs(0);
    for (int i = leaf_base - 1; i > 0; i--) {
        lcatree[i] = min(lcatree[2 * i], lcatree[2 * i + 1]);
    }
    // Every member of a group has to be merged up to the group's LCA.
    for (int g = 0; g < K; g++) {
        // An unused label has no members, and members[g][0] would be out of range.
        if (members[g].empty()) continue;
        int a = members[g][0];
        for (int m : members[g]) a = lca(a, m);
        for (int m : members[g]) mergeuntil[m] = a;
    }
    dfs2(0);
    // A vertex i >= 1 that is still its own union-find root is the top of a
    // contracted component: the edge to its parent was not merged.
    int lcapenible = -1;
    for (int i = 1; i < N; i++) {
        if (root(i) == i) {
            lcapenible = (lcapenible == -1) ? i : lca(lcapenible, i);
            penible[i] = true;
        }
    }
    // lcapenible is an ancestor-or-self of every marked vertex. The only marked
    // vertex that can also be an ancestor of it is therefore lcapenible itself.
    // If lcapenible is marked, the root component has a single child edge and
    // is a leaf as well. (This replaces an equivalent O(N log N) scan.)
    const bool plus1 = lcapenible != -1 && penible[lcapenible];
    int ans = dfs3(0) + plus1;
    cout << ((ans + 1) >> 1) << '\n';
}
// Judge result (# | Result | Execution time | Memory | Grader output):
//   1 | Correct       |  4 ms |  8532 KB | Output is correct
//   2 | Runtime error | 12 ms | 17108 KB | Execution killed with signal 11
//   3 | Halted        |  0 ms |     0 KB | -
// Judge result (# | Result | Execution time | Memory | Grader output):
//   1 | Correct       |  4 ms |  8532 KB | Output is correct
//   2 | Runtime error | 12 ms | 17108 KB | Execution killed with signal 11
//   3 | Halted        |  0 ms |     0 KB | -
// Judge result (# | Result | Execution time | Memory | Grader output):
//   1 | Correct       |  4 ms |  8532 KB | Output is correct
//   2 | Runtime error | 12 ms | 17108 KB | Execution killed with signal 11
//   3 | Halted        |  0 ms |     0 KB | -
// Judge result (# | Result | Execution time | Memory | Grader output):
//   1 | Runtime error | 43 ms | 34052 KB | Execution killed with signal 11
//   2 | Halted        |  0 ms |     0 KB | -
// Judge result (# | Result | Execution time | Memory | Grader output):
//   1 | Correct       |  4 ms |  8532 KB | Output is correct
//   2 | Runtime error | 12 ms | 17108 KB | Execution killed with signal 11
//   3 | Halted        |  0 ms |     0 KB | -