#include "circuit.h"
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const ll maxn=2e5,mod=1e9+2022;
vector<int>adj[maxn+2];
int n,m;
ll tot[maxn+2];
ll dp[maxn+2];
void dfs(int cur){
tot[cur]=adj[cur].size();
if(cur>=n){
tot[cur]=1;
}
for(auto x : adj[cur]){
dfs(x);
tot[cur]=(tot[cur]*tot[x])%mod;
}
}
void solve(int cur,int val){
dp[cur]=val;
if(adj[cur].empty())return;
ll pref[adj[cur].size()+2],suf[adj[cur].size()+2];
pref[0]=tot[adj[cur][0]];
for(int q=1;q<adj[cur].size();q++){
pref[q]=(pref[q-1]*tot[adj[cur][q]])%mod;
}
suf[adj[cur].size()]=1;
for(int q=adj[cur].size()-1;q>=0;q--){
suf[q]=(suf[q+1]*tot[adj[cur][q]])%mod;
}
for(int q=0;q<adj[cur].size();q++){
int baru=(val*suf[q+1])%mod;
if(q){
baru=(baru*pref[q-1])%mod;
}
//if(cur==0 && q)cout<<q<<' '<<adj[cur][q]<<endl;
solve(adj[cur][q],baru);
}
}
struct seg{
int l,r;
ll ny,mt; bool tog=false;
seg *lf,*rg;
void build(int x,int y){
l=x,r=y;
if(l==r){
ny=0,mt=dp[x];
return;
}
int mid=(l+r)/2;
lf=new seg(),rg=new seg();
lf->build(l,mid),rg->build(mid+1,r);
ny=lf->ny+rg->ny,mt=lf->mt+rg->mt;
}
void apply(){
tog=!tog;
swap(ny,mt);
}
void prop(){
if(tog==false)return;
lf->apply(),rg->apply();
tog=false;
}
void update(int posl,int posr){
if(l>posr || r<posl)return;
if(l>=posl && r<=posr){
apply(); return;
}
prop();
lf->update(posl,posr),rg->update(posl,posr);
ny=lf->ny+rg->ny,mt=lf->mt+rg->mt;
}
ll query(){
return ny;
}
};
seg slv;
void init(int N, int M, vector<int> P, vector<int> A) {
n=N,m=M;
for(int q=1;q<n+m;q++){
adj[P[q]].push_back(q);
}
dfs(0);
solve(0,1);
slv.build(n,n+m-1);
for(int q=n;q<n+m;q++){
if(A[q-n]==1){
slv.update(q,q);
}
}
}
int count_ways(int L, int R) {
slv.update(L,R);
return slv.query()%mod;
}