#include "beechtree.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
/// Returns x.size() narrowed to int, for comparisons against int indices.
template <class Container>
static inline int isz(const Container &x) { return static_cast<int>(x.size()); }

/**
 * For every vertex r of a rooted tree, decides whether its subtree T(r) is
 * "beautiful" (presumably the IOI 2023 "Beech Tree" task, judging by the
 * header name and signature).
 *
 * The criterion implemented here accepts T(r) iff
 *   (a) every vertex of T(r) has children with pairwise distinct colours, and
 *   (b) for any two vertices a, b in T(r) with |T(a)| >= |T(b)|, a dominates b:
 *       for every child of b with colour c, a has a child of colour c whose
 *       subtree is at least as large.
 * Dominance is transitive, so (b) is verified only between neighbours of a
 * chain ordered by subtree size. Children's chains are merged into the
 * largest one (small-to-large).
 *
 * @param N  number of vertices; vertex 0 is the root.
 * @param M  number of colours (not used by this method).
 * @param P  P[i] is the parent of vertex i for i >= 1 (P[0] is never read).
 * @param C  C[i] is the colour of vertex i, compared among siblings
 *           (C[0] is never read).
 * @return   N values: 1 if T(r) satisfies (a) and (b), else 0.
 */
std::vector<int> beechtree(int N, int M, std::vector<int> P, std::vector<int> C) {
    (void)M;
    if (N <= 0) return {};

    // Children lists.
    std::vector<std::vector<int>> ch(N);
    for (int i = 1; i < N; ++i) ch[P[i]].push_back(i);

    // Iterative post-order from the root: every vertex appears after all its children.
    std::vector<int> order;
    order.reserve(N);
    std::vector<int> stk{0};
    std::vector<int> next_child(N, 0);
    while (!stk.empty()) {
        const int u = stk.back();
        if (next_child[u] < isz(ch[u])) {
            stk.push_back(ch[u][next_child[u]++]);
        } else {
            stk.pop_back();
            order.push_back(u);
        }
    }

    // Subtree sizes, accumulated bottom-up along the post-order.
    std::vector<int> sz(N, 1);
    for (int u : order)
        for (int v : ch[u]) sz[u] += sz[v];

    // Condition (a) and child profiles (colour -> subtree size of that child).
    // A vertex with a repeated child colour gets unique_ok = 0 and an empty
    // profile; it is rejected before its profile would ever be consulted.
    std::vector<char> unique_ok(N, 1);
    std::vector<std::unordered_map<int, int>> zmap(N);
    std::vector<std::vector<std::pair<int, int>>> zvec(N);  // zmap entries, for iteration
    for (int u = 0; u < N; ++u) {
        auto &profile = zmap[u];
        profile.reserve(ch[u].size() * 2 + 1);
        for (int v : ch[u]) {
            if (!profile.emplace(C[v], sz[v]).second) {
                unique_ok[u] = 0;
                break;
            }
        }
        if (!unique_ok[u]) {
            profile.clear();
            continue;
        }
        zvec[u].assign(profile.begin(), profile.end());
    }

    // a dominates b: every (colour, size) child of b is matched by a child of a
    // with the same colour and a subtree at least as large.
    auto dominates = [&](int a, int b) -> bool {
        const auto &za = zmap[a];
        for (const auto &[col, zb] : zvec[b]) {
            const auto f = za.find(col);
            if (f == za.end() || f->second < zb) return false;
        }
        return true;
    };

    using Chain = std::map<int, int>;  // subtree size -> representative vertex

    // Checks a freshly inserted chain entry against its immediate neighbours.
    auto fits_neighbours = [&](const Chain &chain, Chain::const_iterator pos) -> bool {
        if (pos != chain.begin() && !dominates(pos->second, std::prev(pos)->second)) return false;
        const auto nxt = std::next(pos);
        return nxt == chain.end() || dominates(nxt->second, pos->second);
    };

    // Adds `node` (subtree size `size`) to `chain`; returns false if (b) breaks.
    auto insert_checked = [&](Chain &chain, int size, int node) -> bool {
        const auto res = chain.emplace(size, node);
        if (!res.second) {
            // A duplicate size must not be skipped: equal-size vertices have to
            // dominate each other. With equal totals, one-way dominance already
            // forces identical profiles, so one check suffices and the existing
            // representative stays valid for the whole size class.
            return dominates(res.first->second, node);
        }
        return fits_neighbours(chain, res.first);
    };

    std::vector<Chain> chains(N);
    std::vector<char> good(N, 1);
    for (int u : order) {
        bool ok = unique_ok[u] &&
                  std::all_of(ch[u].begin(), ch[u].end(), [&](int v) { return good[v] != 0; });
        if (ok) {
            // Reuse the largest child chain in place; merge the others into it.
            int heavy = -1;
            for (int v : ch[u])
                if (heavy == -1 || chains[v].size() > chains[heavy].size()) heavy = v;
            if (heavy != -1) chains[u].swap(chains[heavy]);
            for (int v : ch[u]) {
                if (v == heavy) continue;
                for (const auto &[size, rep] : chains[v]) {
                    if (!insert_checked(chains[u], size, rep)) {
                        ok = false;
                        break;
                    }
                }
                chains[v].clear();
                if (!ok) break;
            }
            // sz[u] exceeds every descendant's size, so u goes to the top of the chain.
            ok = ok && insert_checked(chains[u], sz[u], u);
        }
        good[u] = ok ? 1 : 0;
        if (!ok) {
            // Every ancestor of a rejected vertex is rejected as well (it sees a
            // bad child), so these chains are never read again.
            chains[u].clear();
            for (int v : ch[u]) chains[v].clear();
        }
    }
    return std::vector<int>(good.begin(), good.end());
}