#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<ll, ll>;

// Size of the fixed per-vertex tables below; vertex ids must stay below it.
constexpr ll Nm = 200005;

// ans[v]: verdict for the subtree of v; every entry starts out as 1 (true).
vector<bool> ans(Nm, true);
// radj[v]: parent of v (dfs treats -1 as "no parent"); C[v]: colour of v.
vector<ll> radj, C;
// fadj[v]: children of v.
vector<ll> fadj[Nm];
// clen[v]: 0 -> leaf, 1 -> leaf + 1 edge, 2 -> longer.
ll clen[Nm];
// col2[v]: colour of the clen==1 vertex at the bottom of v's deep branch.
ll col2[Nm];
// lbot[v]: id of that bottom vertex.
ll lbot[Nm];
// clin[v]: distinct colours C[y] over the children y of v.
set<ll> clin[Nm];
// Problem sizes: N vertices; M presumably the number of colours.
ll N, M;
// Recursive post-order pass over the child lists fadj[], starting at x.
// For every vertex v it visits, it leaves ans[v] = 1 if T(v) is judged
// beautiful and 0 otherwise. It also fills the summaries the parent reads:
//   clen[v]   0 for a leaf, 1 by default for an internal vertex, and 2 when
//             v's only non-leaf child has clen 1 (branch n2in==0, n1in==1);
//   clin[v]   set of colours C[y] of v's children;
//   col2[v], lbot[v]   colour and id of the clen==1 vertex at the bottom of
//             v's single deep branch (set or inherited below).
// NOTE(review): this is a restricted heuristic, not the general criterion.
// A vertex with two children of clen >= 2 is always rejected, yet such a
// subtree can be beautiful: with P = {-1,0,0,1,2,3,4} and
// C = {0,1,2,1,1,1,1}, T(0) is beautiful but dfs(0) leaves ans[0] = 0.
// Presumably written for a subtask with restricted tree shapes; confirm
// before relying on it.
// NOTE(review): recursion depth equals the height of T(x). On a path this
// can approach Nm frames and may overflow a default-sized stack.
void dfs(ll x) {
    // Children first; a non-beautiful child makes x non-beautiful.
    for (ll y: fadj[x]) {
        dfs(y);
        if (ans[y]==0) {
            ans[x]=0;
        }
    }
    // ans[x] may also have been cleared by a child's "two non-leaf
    // children" rule further down.
    if (ans[x]==0) {
        return;
    }
    if (fadj[x].size()==0) {
        clen[x]=0;
        ans[x]=1;
        return;
    }
    clen[x]=1;
    ll n1in = 0; //number of children with clen >= 1 (non-leaf children)
    ll n2in = 0; //number of children with clen >= 2
    for (ll y: fadj[x]) {
        clin[x].insert(C[y]);
        if (clen[y]>=1) {
            n1in++;
        }
        if (clen[y]>=2) {
            n2in++;
        }
    }
    // Children must have pairwise distinct colours (otherwise clin[x] is
    // smaller than fadj[x]); more than one clen>=2 child is rejected outright.
    if (clin[x].size()!=fadj[x].size() || n2in>=2) {
        ans[x]=0;
        return;
    }
    // Two or more non-leaf children: the parent of x is marked as failing
    // (its own call then returns early after its child loop).
    // NOTE(review): restriction of this heuristic; not justified in the code.
    if (n1in>=2 && radj[x]!=-1) {
        ans[radj[x]]=0;
    }
    if (n2in==0) {
        if (n1in==0) {
            //only leaf children here
            ans[x]=1; //will retrieve later
        } else if (n1in==1) {
            // Exactly one child yc with clen 1: x becomes clen 2 and records
            // yc as the bottom of its deep branch.
            clen[x]=2;
            set<ll> cprev; // NOTE(review): unused
            ll yc = -1;
            for (ll y: fadj[x]) {
                if (clen[y]==1) {
                    yc=y;
                }
            }
            assert(yc!=-1);
            col2[x]=C[yc];
            lbot[x]=yc;
            // Require child colours of yc to be a subset of those of x.
            for (ll c0: clin[yc]) {
                if (clin[x].find(c0)==clin[x].end()) {
                    ans[x]=0;
                    return;
                }
            }
            ans[x]=1;
        } else {
            // Several clen==1 children: sort them by number of child colours
            // and require the colour sets to form a chain under inclusion,
            // with the largest contained in clin[x].
            vector<pii> vst;
            for (ll y: fadj[x]) {
                if (clen[y]==1) {
                    vst.push_back({clin[y].size(),y});
                }
            }
            sort(vst.begin(),vst.end());
            ll K = vst.size();
            for (ll k=0;k<(K-1);k++) {
                ll yc = vst[k].second; ll zc = vst[k+1].second;
                for (ll c0: clin[yc]) {
                    if (clin[zc].find(c0)==clin[zc].end()) {
                        ans[x]=0;
                        return;
                    }
                }
            }
            ll yc = vst[K-1].second; ll zc = x;
            for (ll c0: clin[yc]) {
                if (clin[zc].find(c0)==clin[zc].end()) {
                    ans[x]=0;
                    return;
                }
            }
            ans[x]=1;
            return;
        }
    } else {
        // Exactly one child y2 with clen 2.
        // NOTE(review): clen[x] stays 1 in this branch, so the parent of x
        // treats x as a clen-1 child; confirm this is intended.
        ll y2 = -1;
        for (ll y: fadj[x]) {
            if (clen[y]==2) {
                y2 = y;
            }
        }
        assert(y2!=-1);
        if (col2[y2]!=C[y2]) {
            ans[x]=0;
            return;
        }
        // Inherit the bottom of the deep branch from y2.
        col2[x]=col2[y2];
        lbot[x]=lbot[y2];
        for (ll c0: clin[y2]) {
            if (clin[x].find(c0)==clin[x].end()) {
                ans[x]=0;
                return;
            }
        }
        // y2 is the only non-leaf child.
        if (n1in==1) {
            ans[x]=1;
            return;
        }
        // Other clen==1 children: chain their colour sets under inclusion;
        // the largest must fit inside the colours below lbot[x].
        vector<pii> vst;
        for (ll y: fadj[x]) {
            if (clen[y]==1) {
                vst.push_back({clin[y].size(),y});
            }
        }
        sort(vst.begin(),vst.end());
        ll K = vst.size();
        for (ll k=0;k<(K-1);k++) {
            ll yc = vst[k].second; ll zc = vst[k+1].second;
            for (ll c0: clin[yc]) {
                if (clin[zc].find(c0)==clin[zc].end()) {
                    ans[x]=0;
                    return;
                }
            }
        }
        ll yc = vst[K-1].second; ll zc = lbot[x];
        for (ll c0: clin[yc]) {
            if (clin[zc].find(c0)==clin[zc].end()) {
                ans[x]=0;
                return;
            }
        }
        ans[x]=1;
    }
}
// Decides, for every vertex r, whether the subtree T(r) is beautiful.
//
// T(r) is beautiful iff both of the following hold inside T(r):
//   1. every vertex has children with pairwise distinct colours;
//   2. when the vertices of T(r) are listed by non-increasing subtree size
//      (ties in any order), each vertex u dominates the next vertex w.
//      "u dominates w" means: for every child w' of w with colour c, u has a
//      child u' of colour c with |T(u')| >= |T(w')|.
// Every subtree of a beautiful subtree is beautiful, so a failure at a
// vertex is propagated to all of its ancestors.
//
// Vertices are processed from N-1 down to 0. This visits children before
// parents because the task guarantees P[i] < i for i >= 1 (assumed, not
// checked). No recursion is used, so tall trees cannot overflow the stack.
// Each surviving vertex owns an ordered set of (subtree size, id) for its
// subtree, plus the number of adjacent pairs that violate rule 2. Child
// sets are merged into the parent small-to-large. All state is local, so
// repeated calls are independent.
//
// _N: number of vertices.
// _M: number of colours (not needed by the algorithm).
// _P: parent of each vertex, _P[0] == -1.
// _C: colour of the edge from i to _P[i] (_C[0] is not read).
// Returns _N entries: 1 if T(r) is beautiful, 0 otherwise.
std::vector<int> beechtree(int _N, int _M, std::vector<int> _P, std::vector<int> _C) {
    (void)_M;
    const int n = _N;

    // Subtree sizes; valid because every child index exceeds its parent's.
    std::vector<int> sz(n, 1);
    for (int v = n - 1; v >= 1; --v) sz[_P[v]] += sz[v];

    // kids[v]: (colour, child) pairs sorted by colour, for binary search.
    std::vector<std::vector<std::pair<int, int> > > kids(n);
    for (int v = 1; v < n; ++v) kids[_P[v]].emplace_back(_C[v], v);

    // failed[v]: T(v) is known not to be beautiful.
    std::vector<char> failed(n, 0);
    for (int v = 0; v < n; ++v) {
        std::sort(kids[v].begin(), kids[v].end());
        for (std::size_t i = 1; i < kids[v].size(); ++i) {
            if (kids[v][i].first == kids[v][i - 1].first) failed[v] = 1;  // repeated colour
        }
    }

    // True if u dominates w (rule 2 above).
    auto dominates = [&](int u, int w) -> bool {
        const std::vector<std::pair<int, int> >& ku = kids[u];
        const std::vector<std::pair<int, int> >& kw = kids[w];
        if (kw.size() > ku.size()) return false;  // some colour of w is missing in u
        for (const std::pair<int, int>& e : kw) {
            // Child ids are >= 1, so (colour, -1) sorts before any entry of that colour.
            auto it = std::lower_bound(ku.begin(), ku.end(), std::make_pair(e.first, -1));
            if (it == ku.end() || it->first != e.first || sz[it->second] < sz[e.second]) {
                return false;
            }
        }
        return true;
    };

    // Insert vertex v into s (ascending by (size, id)). Keep `bad` equal to
    // the number of adjacent pairs (lower, upper) where upper does not
    // dominate lower.
    auto insert_vertex = [&](std::set<std::pair<int, int> >& s, int& bad, int v) {
        const std::pair<int, int> key(sz[v], v);
        auto upper = s.lower_bound(key);  // ids are unique, so this is the first larger element
        const bool has_upper = upper != s.end();
        const bool has_lower = upper != s.begin();
        const int up = has_upper ? upper->second : -1;
        const int lo = has_lower ? std::prev(upper)->second : -1;
        if (has_upper && has_lower && !dominates(up, lo)) --bad;  // pair (lo, up) is split
        if (has_upper && !dominates(up, v)) ++bad;
        if (has_lower && !dominates(v, lo)) ++bad;
        s.insert(upper, key);
    };

    std::vector<std::set<std::pair<int, int> > > bag(n);
    std::vector<int> bad(n, 0);
    std::vector<int> result(n, 0);
    for (int v = n - 1; v >= 0; --v) {
        const int p = (v == 0) ? -1 : _P[v];
        if (!failed[v]) {
            insert_vertex(bag[v], bad[v], v);
            if (bad[v] > 0) failed[v] = 1;
        }
        if (failed[v]) {
            if (p >= 0) failed[p] = 1;  // a non-beautiful subtree sinks every ancestor
            bag[v].clear();
            continue;
        }
        result[v] = 1;
        if (p < 0 || failed[p]) {
            bag[v].clear();
            continue;
        }
        // Small-to-large: always re-insert the elements of the smaller set.
        if (bag[v].size() > bag[p].size()) {
            std::swap(bag[v], bag[p]);
            std::swap(bad[v], bad[p]);
        }
        for (const std::pair<int, int>& key : bag[v]) insert_vertex(bag[p], bad[p], key.second);
        bag[v].clear();
    }
    return result;
}
/*
 * Judge status table pasted from the submission page; the results were
 * still loading. It is kept verbatim inside this comment so the file
 * compiles.
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */