#include "beechtree.h"
#include<iostream>
#include<algorithm>
#include<vector>
#include<queue>
#include<set>
#include<map>
#include<tuple>
using namespace std;
// Capacity of the global per-vertex arrays below.
const int MAX_N = 200005;
// Vertex count and colour count of the current tree (set by beechtree()).
int n, m;
// g[u] lists the children of u as (child, colour of the edge to that child).
std::vector<std::pair<int, int>> g[MAX_N];
// par[v] / col[v]: parent of v and colour of the edge (par[v], v), for v >= 1.
int par[MAX_N];
int col[MAX_N];
// good[u]: cleared by dfsgood() when some vertex in u's subtree (u included)
// has two children reached by edges of the same colour.
bool good[MAX_N];

// Computes good[] for every vertex in the subtree of u.
//
// good[x] is cleared when two children of x share an edge colour, or when the
// flag of any child of x is cleared; a flag that starts out false stays false.
// Iterative, so a path-shaped tree of MAX_N vertices cannot exhaust the call
// stack the way one recursive frame per level could.
void dfsgood(int u)
{
    // Breadth-first listing of the subtree: every vertex appears after its
    // parent, so walking the list backwards handles children before parents.
    std::vector<int> order{u};
    for (std::size_t i = 0; i < order.size(); ++i)
    {
        for (const auto& edge : g[order[i]])
        {
            order.push_back(edge.first);
        }
    }
    for (auto it = order.rbegin(); it != order.rend(); ++it)
    {
        const int x = *it;
        std::set<int> seen_colours;
        for (const auto& [child, colour] : g[x])
        {
            // insert() reports false when the colour was already present.
            if (!seen_colours.insert(colour).second)
            {
                good[x] = 0;
            }
            good[x] = std::min(good[x], good[child]);
        }
    }
}
// Proper descendants collected by dfsnode(); callers clear it beforehand.
std::vector<int> nodes;
// Set by dfsnode() when a visited vertex has an odd number of children;
// dfsnode() never clears it.
bool odd;

// Appends every proper descendant of u to `nodes` in pre-order (children in
// g[] order) and sets `odd` if u or any descendant has an odd child count.
// Uses an explicit stack instead of recursion so deep trees cannot overflow
// the call stack.
void dfsnode(int u)
{
    std::vector<int> pending;
    // Marks `odd` for v and schedules its children. They are pushed in reverse
    // so the first child is popped first, reproducing recursive pre-order.
    auto expand = [&pending](int v) {
        if (g[v].size() % 2 == 1)
        {
            odd = 1;
        }
        for (auto it = g[v].rbegin(); it != g[v].rend(); ++it)
        {
            pending.push_back(it->first);
        }
    };
    expand(u);
    while (!pending.empty())
    {
        const int v = pending.back();
        pending.pop_back();
        nodes.push_back(v);
        expand(v);
    }
}
// Globals used by solve(). used, cntcols, eachcol and szseq are reset at the
// start of every solve() call; seq[] is not, and touched[] is only cleared for
// the descendants of the subtree being processed.
// eachcol[c]: (distance, ancestor) pairs recorded while walking up runs of
// colour-c edges; each list is sorted ascending before the placement rounds.
map<int,vector<pair<int,int>>>eachcol;
// touched[u][c]: a colour-c walk has already passed through u in this call.
map<int,bool>touched[MAX_N];
// Vertices already appended to seq in the current call.
set<int>used;
// seq[0..szseq) is the order built so far; seq[0] must be the subtree root.
// Entries at and beyond szseq keep whatever an earlier call wrote there.
int seq[MAX_N];
int szseq;
// cntcols[c]: how many vertices placed after seq[0] have edge colour c.
map<int,int>cntcols;
// Tries to lay out the subtree rooted at r as seq[0], seq[1], ...:
// seq[0] must be r, and each later vertex x is accepted only if
// par[x] == seq[cntcols[col[x]]], i.e. its parent sits at the position given by
// the number of already-placed vertices (after seq[0]) sharing col[x].
// Returns 1 when every vertex of the subtree has been placed, 0 as soon as one
// check fails. Subtrees with at most one proper descendant return 1 at once.
// NOTE(review): only one deterministic order is attempted per call, so 0
// presumably means "this order failed", not that no valid order exists.
//
// r:    root of the subtree to check.
// fuck: colour tie-break inside a block; 0 sorts candidates by descending
//       col[], any other value by ascending col[].
//
// Side effects: overwrites nodes, used, cntcols, eachcol, szseq, seq[0..szseq)
// and touched[] of r's descendants; dfsnode() may set the global `odd`.
// The whole subtree is re-walked on every call.
bool solve(int r,int fuck)
{
    // NOTE(review): the reset of `odd` and the early exit on good[r] are
    // disabled, so neither global influences the result.
    //odd=0;
    //if(good[r]==0)return 0;
    used.clear();
    nodes.clear();
    cntcols.clear();
    eachcol.clear();
    szseq=0;
    dfsnode(r);
    if(nodes.size()<=1)return 1;
    // Highest vertex id first. NOTE(review): if parents always have smaller ids
    // than their children, this processes children before parents — confirm.
    sort(nodes.rbegin(),nodes.rend());
    for(int u:nodes)touched[u].clear();
    map<int,int>numroots;
    int important;
    int ako=0;
    // From every vertex u not yet touched for its own colour col[u], walk up
    // while the edge colour stays col[u]: each vertex reached at distance sz>0
    // is recorded as (sz, vertex) in eachcol[col[u]]. The walk records and then
    // stops at the first vertex whose own edge colour differs, or at r.
    // NOTE(review): `touched` is only checked for the starting vertex, so e.g.
    // two same-coloured siblings both record their parent at distance 1, and
    // such duplicates can end up in the same block below — confirm intended.
    for(int u:nodes)
    {
        int curcol=col[u];
        if(touched[u][curcol]==1)continue;
        int sz=0;
        while(1)
        {
            if(sz>0)eachcol[curcol].push_back({sz,u});
            touched[u][curcol]=1;
            sz++;
            if(col[u]!=curcol or u==r)break;
            u=par[u];
        }
        numroots[curcol]+=sz;
    }
    // NOTE(review): `important` is written here but never read afterwards, so
    // this loop (and ako/numroots) does not affect the result.
    for(auto [curcol,roots]:numroots)
    {
        if(roots>ako)
        {
            ako=roots;
            important=curcol;
        }
    }
    for(auto&[curcol,blockcol]:eachcol)
    {
        sort(blockcol.begin(),blockcol.end());
    }
    vector<pair<int,int>>blocks;
    map<int,vector<int>>curblock;
    // Placement rounds. Each round takes, for every colour, the unused entries
    // with the largest remaining distance (a "block"), then places the blocks
    // in ascending (block size, colour) order. Stops once all nodes.size()+1
    // vertices are placed or no colour has entries left.
    while(szseq!=nodes.size()+1)
    {
        blocks.clear();
        curblock.clear();
        int rem=0;
        for(auto&[curcol,blockcol]:eachcol)
        {
            while(blockcol.size() && used.count(blockcol.back().second))blockcol.pop_back();
            // cursz is only read when blockcol is non-empty (short-circuit below).
            int cursz;
            if(blockcol.size())
            {
                rem++;
                cursz=blockcol.back().first;
            }
            while(blockcol.size() && blockcol.back().first==cursz)
            {
                if(used.count(blockcol.back().second)==0)curblock[curcol].push_back(blockcol.back().second);
                blockcol.pop_back();
            }
            blocks.push_back({curblock[curcol].size(),curcol});
        }
        if(rem==0)break;
        sort(blocks.begin(),blocks.end());
        for(auto [SEX,curcol]:blocks)
        {
            // Candidates sorted by (seq[par[u]], colour tie-break, id); r gets
            // -1e9 for both keys so it sorts first.
            // NOTE(review): seq is indexed by position but par[u] is a vertex
            // id, so this reads whichever vertex sits at position par[u] —
            // possibly a stale value from an earlier call, since seq is never
            // cleared. If the intent was "position of par[u] in seq", an inverse
            // array is needed; as written the order (and result) can depend on
            // earlier solve() calls.
            vector<tuple<int,int,int>>candidates;
            for(int u:curblock[curcol])
            {
                if(used.count(u))continue;
                int where=-1e9;
                if(u!=r)where=seq[par[u]];
                // The ?: has type double (because of -1e9); it is converted
                // back to int when the tuple is built.
                candidates.push_back({where,(u==r ? -1e9 : (fuck==0 ? -col[u] : col[u])),u});
            }
            sort(candidates.begin(),candidates.end());
            // The first vertex placed must be r; every later one must pass the
            // parent check and then bumps the count for its colour.
            for(auto [SEX3,SEX2,x]:candidates)
            {
                if(szseq==0)
                {
                    if(x!=r)return 0;
                }
                else
                {
                    if(seq[cntcols[col[x]]]!=par[x])return 0;
                    cntcols[col[x]]++;
                }
                seq[szseq++]=x;
                used.insert(x);
            }
        }
    }
    // Vertices not placed through any block are appended last, ordered by
    // (seq[par[u]], id) and subject to the same checks (the NOTE on
    // seq[par[u]] above applies here too).
    vector<pair<int,int>>candidates;
    for(int u:nodes)
    {
        if(used.count(u))continue;
        int where=-1e9;
        if(u!=r)where=seq[par[u]];
        candidates.push_back({where,u});
    }
    sort(candidates.begin(),candidates.end());
    for(auto [SEX2,x]:candidates)
    {
        if(szseq==0)
        {
            if(x!=r)return 0;
        }
        else
        {
            if(seq[cntcols[col[x]]]!=par[x])return 0;
            cntcols[col[x]]++;
        }
        seq[szseq++]=x;
        used.insert(x);
    }
    return (szseq==(nodes.size()+1));
}
// Entry point: for every vertex i of the tree described by P and C, reports 1
// if solve() manages to order the subtree rooted at i with either colour
// tie-break, otherwise 0.
//
// N: number of vertices; P[i] (parent) and C[i] (edge colour) are read for
//    i >= 1, vertex 0 being the root.
// M: stored in the global m; not otherwise used here.
// Returns N entries, each 0 or 1.
std::vector<int> beechtree(int N, int M, std::vector<int> P, std::vector<int> C)
{
    n = N;
    m = M;
    // g[] is global and only ever appended to below, so drop any children left
    // over from a previous call before rebuilding the tree.
    for (int i = 0; i < n; i++)
    {
        g[i].clear();
    }
    for (int i = 1; i < n; i++)
    {
        par[i] = P[i];
        col[i] = C[i];
        g[P[i]].push_back({i, C[i]});
    }
    for (int i = 0; i < n; i++)
    {
        good[i] = 1;
    }
    // NOTE(review): solve() currently has its good[r] check commented out, so
    // these flags do not influence the answer.
    dfsgood(0);
    std::vector<int> ans(n);
    for (int i = 0; i < n; i++)
    {
        // Both tie-breaks are always evaluated. Their relative order is
        // unspecified, and solve() reads seq[] entries left by earlier calls,
        // so the outcome may depend on the compiler's choice; the expression
        // is kept as-is to preserve current behaviour.
        ans[i] = std::max(solve(i, 1), solve(i, 0));
    }
    return ans;
}
/*
 * NOTE: the lines below look like residue pasted from a submission-results
 * page (every row still reads "Fetching results..."); they are kept verbatim
 * inside this comment so the file remains valid C++.
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... | 
| # | Verdict  | Execution time | Memory | Grader output | 
|---|
| Fetching results... |
*/