#include "beechtree.h"
#include <bits/stdc++.h>
using namespace std;
// A mergeable sequence of (key, per-colour map) entries built by dp() and
// combined by merge(). dp() appends entries keyed by subtree sizes and stores
// per-colour size differences in the map; `size` is overwritten by dp() with
// the subtree size of the node that owns this sequence.
// A freshly constructed sequence holds exactly one entry: key 1 with an empty
// map, and size 1.
struct treeseq {
    std::list<std::pair<int, std::unordered_map<int, int>>> seq;
    int size = 1;

    treeseq() {
        seq.emplace_back(1, std::unordered_map<int, int>());
    }
};
// children[v]: (edge colour, child id) pairs of node v, filled and sorted by
// colour in beechtree().
std::vector<std::vector<std::pair<int, int>>> children;
// subtreesize[v]: number of nodes in the subtree rooted at v (set by getsubtreesize()).
std::vector<int> subtreesize;
// ret[v]: set to 1 by dp() when the subtree rooted at v is accepted, otherwise left at 0.
std::vector<int> ret;
// Computes and caches the number of nodes in the subtree rooted at `node`.
// If children[node] contains an entry pointing back at `parent`, the first such
// entry is removed before recursing, so the adjacency list ends up holding only
// true children. Stores the result in subtreesize[node] and returns it.
// Note: recursion depth equals the tree height.
int getsubtreesize(int node, int parent) {
    vector<pair<int, int>>& adj = children[node];

    auto back_edge = find_if(adj.begin(), adj.end(),
                             [parent](const pair<int, int>& edge) { return edge.second == parent; });
    if(back_edge != adj.end()) adj.erase(back_edge);

    int total = 1;  // count `node` itself
    for(const pair<int, int>& edge : adj) {
        total += getsubtreesize(edge.second, node);
    }
    subtreesize[node] = total;
    return total;
}
// Merges two size sequences and returns the combined one, or nullptr if they
// are incompatible.
//
// Ownership: takes ownership of both arguments. On success the sequence with
// the larger `size` is returned and the other one is deleted; on failure (or if
// either argument is nullptr) both are deleted and nullptr is returned.
//
// Assumed invariants (as built by dp(); NOTE(review): not re-checked here):
// each seq is sorted by ascending key, keys are distinct within one seq, the
// last key equals `size`, and each map holds non-negative per-colour deltas
// relative to the preceding entry (an absent colour means delta 0).
//
// The two lists are walked in lockstep. Entries with equal keys must carry
// identical deltas. Otherwise the entry with the smaller key (it_adding) is
// copied into the other list just before the entry it now precedes
// (it_removing), whose deltas are reduced by it_adding's deltas.
treeseq* merge(treeseq* left, treeseq* right) {
    if(left == nullptr || right == nullptr) {
        // delete on nullptr is a no-op, so this frees whichever one is valid.
        delete left;
        delete right;
        return nullptr;
    }
    // Always merge into the larger sequence; it is the one returned.
    if(left->size < right->size) swap(left, right);

    auto it1 = left->seq.begin();
    auto it2 = right->seq.begin();
    while(it1 != left->seq.end() && it2 != right->seq.end()) {
        if(it1->first == it2->first) {
            // Same key: the deltas must agree. find() avoids inserting zero
            // entries into right's map the way operator[] would.
            for(const auto& p : it1->second) {
                auto other = it2->second.find(p.first);
                int other_delta = (other == it2->second.end()) ? 0 : other->second;
                if(p.second != other_delta) {
                    delete left;
                    delete right;
                    return nullptr;
                }
            }
            ++it1;
            ++it2;
            continue;
        }

        auto it_adding = it1;
        auto it_removing = it1;
        if(it1->first > it2->first) {
            it_adding = it2++;
            it_removing = it1;
            // Key the copy with it_adding's key: it2 has already advanced and
            // may be end(), so it2->first would be wrong or undefined.
            left->seq.insert(it_removing, make_pair(it_adding->first, unordered_map<int, int>()));
        }
        else {
            it_adding = it1++;
            it_removing = it2;
            right->seq.insert(it_removing, make_pair(it_adding->first, unordered_map<int, int>()));
        }

        // it_split is the freshly inserted copy of it_adding (directly before
        // it_removing). Move it_adding's deltas onto it and take them off
        // it_removing, which must not go negative.
        auto it_split = it_removing;
        --it_split;
        for(const auto& p : it_adding->second) {
            int& remaining = it_removing->second[p.first];
            remaining -= p.second;
            if(remaining < 0) {
                delete left;
                delete right;
                return nullptr;
            }
            it_split->second[p.first] += p.second;
        }
    }

    // left now holds the combined sequence; right is no longer needed.
    delete right;
    return left;
}
// Processes the subtree rooted at `node` bottom-up. Returns its size sequence,
// or nullptr if the subtree is rejected (a rejected child makes merge() return
// nullptr, so rejection propagates to every ancestor). Sets ret[node] = 1 when
// the subtree is accepted. All children are recursed into before any early
// exit, so ret[] is filled for every descendant even when `node` is rejected.
// The caller owns the returned treeseq.
//
// Requires children[] sorted by colour and subtreesize[] filled (done by beechtree()).
// Note: recursion depth equals the tree height.
treeseq* dp(int node) {
    treeseq* sq = new treeseq;
    vector<treeseq*> to_merge;
    for(const pair<int, int>& e : children[node]) to_merge.push_back(dp(e.second));
    for(treeseq* m : to_merge) {
        sq = merge(sq, m);
    }

    // Children are sorted by colour, so a repeated colour shows up as adjacent equal entries.
    for(size_t i = 1; i < children[node].size(); i++) {
        if(children[node][i - 1].first == children[node][i].first) {
            delete sq;
            return nullptr;
        }
    }
    if(sq == nullptr) return nullptr;

    sq->size = subtreesize[node];
    if(subtreesize[node] == 1) {
        ret[node] = 1;
        return sq;
    }

    // Compare `node` against its largest child subtree (ties broken by larger id).
    pair<int, int> biggest_subtree = make_pair(-1, -1);
    for(const pair<int, int>& e : children[node]) {
        biggest_subtree = max(biggest_subtree, make_pair(subtreesize[e.second], e.second));
    }
    const vector<pair<int, int>>& big_children = children[biggest_subtree.second];
    auto it = big_children.begin();

    // Append node's entry: per colour, node's child-subtree size minus the
    // biggest child's child-subtree size of the same colour (or the full size
    // when the biggest child has no child of that colour).
    auto newnodeit = sq->seq.insert(sq->seq.end(), make_pair(subtreesize[node], unordered_map<int, int>()));
    for(const pair<int, int>& e : children[node]) {
        if(it != big_children.end() && it->first < e.first) {
            // The biggest child has colour it->first, but node does not.
            delete sq;
            return nullptr;
        }
        else if(it != big_children.end() && it->first == e.first) {
            int diff = subtreesize[e.second] - subtreesize[it->second];
            if(diff < 0) {
                delete sq;
                return nullptr;
            }
            newnodeit->second[e.first] = diff;
            ++it;
        }
        else {
            newnodeit->second[e.first] = subtreesize[e.second];
        }
    }
    // Colours of the biggest child above every colour of `node` were never
    // matched in the loop; they are missing at `node` too, so reject.
    if(it != big_children.end()) {
        delete sq;
        return nullptr;
    }

    ret[node] = 1;
    return sq;
}
// Entry point. Builds the rooted tree from parent array P (P[i] is the parent
// of node i for i >= 1) and edge colours C (C[i] is the colour of the edge into
// node i). Returns a vector of length N where entry v is 1 if dp() accepted
// the subtree rooted at v, and 0 otherwise. M is not used by this algorithm.
vector<int> beechtree(int N, int M, vector<int> P, vector<int> C) {
    (void)M;
    if(N <= 0) return {};

    // Reset global state so a repeated call cannot see data left over from a previous one.
    children.clear();
    children.resize(N);
    subtreesize.assign(N, 0);
    ret.assign(N, 0);

    for(int i = 1; i < N; i++) {
        children[P[i]].push_back(make_pair(C[i], i));
    }
    // dp() relies on each child list being sorted by colour.
    for(auto& kids : children) {
        sort(kids.begin(), kids.end());
    }

    getsubtreesize(0, -1);
    // Only ret[] is needed; free the root's sequence instead of leaking it.
    delete dp(0);
    return ret;
}