#include <bits/stdc++.h>
using namespace std;
// Capacity of the per-vertex arrays; vertex ids are used directly as indices.
constexpr int kMaxVertices = 100005;
// Vertex count of the current tree (assigned by initialize).
int n;
// adj[v] lists the neighbours of vertex v.
vector<int> adj[kMaxVertices];
// state[v] is the label of vertex v: 0 = empty (neighbor), 1 = cat, 2 = dog.
int state[kMaxVertices];
// Counts labelled vertices in the subtree rooted at `node`, where the tree is
// entered from `parent` (pass -1 for the root).
// Returns {number of cats (state 1), number of dogs (state 2)}.
// Uses an explicit work stack instead of recursion; each pending entry is a
// vertex together with the neighbour it was reached from.
pair<int,int> dfs(int node, int parent) {
    int catCount = 0;
    int dogCount = 0;
    vector<pair<int, int>> pending{{node, parent}};
    while (!pending.empty()) {
        const pair<int, int> top = pending.back();
        pending.pop_back();
        const int vertex = top.first;
        const int cameFrom = top.second;

        if (state[vertex] == 1) {
            ++catCount;
        } else if (state[vertex] == 2) {
            ++dogCount;
        }

        for (int neighbour : adj[vertex]) {
            if (neighbour == cameFrom) continue;
            pending.emplace_back(neighbour, vertex);
        }
    }
    return make_pair(catCount, dogCount);
}
// Danger level most recently computed by calculateDanger().
int danger = 0;
// Walks every edge below `node` (entered from `parent`).
// For each edge it increments the global `danger` when the lower side holds
// cats while the upper side holds dogs, or vice versa.
// totalCats/totalDogs are the label counts for the whole tree.
// Subtree counts are recomputed with dfs() for each edge, so a full walk can
// cost up to O(n^2).
void solve(int node, int parent, int totalCats, int totalDogs) {
    for (int child : adj[node]) {
        if (child == parent) continue;

        const pair<int, int> below = dfs(child, node);
        const int catsAbove = totalCats - below.first;
        const int dogsAbove = totalDogs - below.second;

        const bool catsBelowDogsAbove = below.first > 0 && dogsAbove > 0;
        const bool dogsBelowCatsAbove = below.second > 0 && catsAbove > 0;
        if (catsBelowDogsAbove || dogsBelowCatsAbove) ++danger;

        solve(child, node, totalCats, totalDogs);
    }
}
int calculateDanger() {
danger = 0;
auto [totalCats, totalDogs] = dfs(1, -1);
if (totalCats == 0 || totalDogs == 0) return 0;
solve(1, -1, totalCats, totalDogs);
return danger;
}
// Builds the tree on vertices 1..N from the N-1 edges (A[i], B[i]) and marks
// every vertex as empty (state 0).
// Adjacency lists are cleared first, covering both the previous vertex count
// and the new one. Without this, a second call to initialize would keep the
// edges of the earlier tree.
void initialize(int N, vector<int> A, vector<int> B) {
    const int clearUpTo = max(n, N);
    for (int v = 1; v <= clearUpTo; ++v) adj[v].clear();

    n = N;
    for (int i = 0; i < N - 1; i++) {
        adj[A[i]].push_back(B[i]);
        adj[B[i]].push_back(A[i]);
    }
    fill(state + 1, state + N + 1, 0);
}
// Query handlers: each relabels vertex v and returns the danger level
// recomputed by calculateDanger().
static int relabelAndMeasure(int v, int label) {
    state[v] = label;
    return calculateDanger();
}

// Vertex v now holds a cat.
int cat(int v) { return relabelAndMeasure(v, 1); }

// Vertex v now holds a dog.
int dog(int v) { return relabelAndMeasure(v, 2); }

// Vertex v becomes empty again.
int neighbor(int v) { return relabelAndMeasure(v, 0); }