#include <iostream>
#include <vector>
#include <functional>
using namespace std;
using ll = long long;

// Capacity limits. Vertex ids are 1-based; N is assumed to be < MAXN (about 5e5) — TODO confirm
// against the problem's stated bounds.
constexpr int MAXN = 500005;
// Binary-lifting levels: jumps of 2^0..2^18 sum to 2^19 - 1 >= 5e5, so 19 levels reach any ancestor.
constexpr int LOG = 19;

vector<int> adj[MAXN];    // adjacency lists of the tree
int di[MAXN];             // di[v]: number of same-state paths using edge (v, parent(v)); nonzero => edge must stay
int in[MAXN], out[MAXN];  // Euler-tour positions of v's entry (+v) and exit (-v)
vector<int> ett;          // doubled Euler tour: +v on entry, -v on exit (2N entries)
int prf[2 * MAXN];        // difference array, then prefix sums, over Euler-tour positions (needs >= 2N slots)
int sparse[MAXN][LOG];    // sparse[v][i]: 2^i-th ancestor of v, 0 above the root
int dep[MAXN];            // depth; dep[0] == 0 for the sentinel vertex 0, the root (vertex 1) has depth 1
vector<int> area[MAXN];   // area[s]: vertices whose state id is s
int N, K;                 // number of vertices, number of states
int deg[MAXN];            // degree of each component in the contracted tree
int main()
{
cin >> N >> K;
for (int i = 1; i < N; i++) {
int u, v; cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
function<void(int, int)> dfs = [&](int v, int p)
{
dep[v] = dep[p] + 1;
sparse[v][0] = p;
for (int i = 1; i < 19; i++) sparse[v][i] = sparse[sparse[v][i-1]][i-1];
in[v] = ett.size();
ett.push_back(v);
for (int i:adj[v]) {
if (i != p) {
dfs(i, v);
}
}
out[v] = ett.size();
ett.push_back(-v);
};
dfs(1, 0);
function<int(int, int)> lca = [&](int u, int v)
{
if (dep[u] > dep[v]) swap(u, v);
for (int i = 18; i >= 0; i--) {
if (dep[sparse[v][i]] >= dep[u]) v = sparse[v][i];
}
if (u == v) return u;
for (int i = 18; i >= 0; i--) {
if (sparse[u][i] != sparse[v][i]) {
u = sparse[u][i];
v = sparse[v][i];
}
}
return u;
};
function<int(int, int)> last = [&](int p, int v)
{
for (int i = 18; i >= 0; i--) {
if (dep[sparse[v][i]] > dep[p]) v = sparse[v][i];
}
return v;
};
function<void(int, int)> add_anc = [&](int p, int v)
{
prf[in[p]]++;
prf[in[v]+1]--;
};
function<void(int, int)> add_path = [&](int u, int v)
{
int p = lca(u, v);
if (p != u) add_anc(last(p, u), u);
if (p != v) add_anc(last(p, v), v);
};
for (int i = 1; i <= N; i++) {
int x; cin >> x;
area[x].push_back(i);
}
for (int i = 1; i <= K; i++) {
for (int j = 0; j < (int)area[i].size(); j++) {
add_path(area[i][j], area[i][(j+1)%area[i].size()]);
}
}
for (int i = 0; i < (int)ett.size(); i++) {
if (i) prf[i] += prf[i-1];
if (ett[i] > 0) di[ett[i]] += prf[i];
else di[-ett[i]] -= prf[i];
}
//di[i]: i와 자기 부모를 잇는 간선이 갈리는가
int ord = 0;
function<void(int, int, int)> dfs2 = [&](int v, int p, int num)
{
for (int i:adj[v]) {
if (i != p) {
if (di[i]) {
dfs2(i, v, num);
}
else {
ord++;
deg[num]++;
deg[ord]++;
dfs2(i, v, ord);
}
}
}
};
dfs2(1, 0, 0);
int ans = 0;
for (int i = 0; i <= ord; i++) ans += (deg[i] == 1);
if (ans == 0) cout << 0 << '\n';
else cout << (ans + 1) / 2 << '\n';
return 0;
}
/*
 * Grader results (pasted submission log):
 * #  | Result (결과) | Time (실행 시간) | Memory (메모리) | Grader output
 * 1  | Correct       | 14 ms           | 23764 KB        | Output is correct
 * 2  | Incorrect     | 13 ms           | 23792 KB        | Output isn't correct
 * 3  | Halted        | 0 ms            | 0 KB            | -
 */
/*
 * Grader results (pasted submission log):
 * #  | Result (결과) | Time (실행 시간) | Memory (메모리) | Grader output
 * 1  | Correct       | 14 ms           | 23764 KB        | Output is correct
 * 2  | Incorrect     | 13 ms           | 23792 KB        | Output isn't correct
 * 3  | Halted        | 0 ms            | 0 KB            | -
 */
/*
 * Grader results (pasted submission log):
 * #  | Result (결과) | Time (실행 시간) | Memory (메모리) | Grader output
 * 1  | Correct       | 14 ms           | 23764 KB        | Output is correct
 * 2  | Incorrect     | 13 ms           | 23792 KB        | Output isn't correct
 * 3  | Halted        | 0 ms            | 0 KB            | -
 */
/*
 * Grader results (pasted submission log):
 * #  | Result (결과) | Time (실행 시간) | Memory (메모리) | Grader output
 * 1  | Correct       | 166 ms          | 38380 KB        | Output is correct
 * 2  | Correct       | 178 ms          | 41472 KB        | Output is correct
 * 3  | Correct       | 17 ms           | 24324 KB        | Output is correct
 * 4  | Incorrect     | 16 ms           | 24332 KB        | Output isn't correct
 * 5  | Halted        | 0 ms            | 0 KB            | -
 */
/*
 * Grader results (pasted submission log):
 * #  | Result (결과) | Time (실행 시간) | Memory (메모리) | Grader output
 * 1  | Correct       | 14 ms           | 23764 KB        | Output is correct
 * 2  | Incorrect     | 13 ms           | 23792 KB        | Output isn't correct
 * 3  | Halted        | 0 ms            | 0 KB            | -
 */