#include "circuit.h"
#include <cassert>
#include <cstdio>
#include <vector>
#include <iostream>
using namespace std;
// Children lists of the gate tree: adj[g] holds every gate i whose parent P[i] is g.
vector<int> adj[200005];
// Modulus applied to every product (answers are reported modulo this value).
long long mod = 1000002022;
// contr[g]: weight pushed down from the root by dfs2 (contr[0] == 1).
long long contr[200005];
// modes[g]: product, over every gate with inputs in g's subtree, of its input count (1 for a leaf), mod `mod`.
long long modes[200005];
// Post-order pass: modes[curr] = (number of inputs of curr) * prod(modes[child]) mod `mod`,
// with every leaf (gate without inputs) fixed to 1.
void dfs(int curr)
{
    const vector<int>& kids = adj[curr];
    if (kids.empty())
    {
        modes[curr] = 1;
        return;
    }
    long long ways = kids.size();
    for (int child : kids)
    {
        dfs(child);
        ways = (ways * modes[child]) % mod;
    }
    modes[curr] = ways;
}
void dfs2(int curr)
{
if(adj[curr].size()==0)
{
return;
}
vector<long long>prefix;
vector<long long>suffix;
vector<long long>vec;
for(auto k:adj[curr])
{
vec.push_back(modes[k]);
}
prefix=vec;
suffix=vec;
for(int i=1;i<vec.size();i++)
{
prefix[i]=(prefix[i-1]*prefix[i])%mod;
}
for(int i=vec.size()-2;i>=0;i--)
{
suffix[i]=(suffix[i+1]*suffix[i])%mod;
}
for(int i=0;i<adj[curr].size();i++)
{
int k=adj[curr][i];
long long a,b;
if(i!=0)
{
a=prefix[i-1];
}
else
{
a=1;
}
if(i!=adj[curr].size()-1)
{
b=suffix[i+1];
}
else
{
b=1;
}
contr[k]=((a*b)%mod*contr[curr])%mod;
dfs2(k);
}
}
// Current state (0/1) of each source gate, indexed by (gate id - n).
int state[200005];
// Number of threshold gates and source gates, as passed to init().
int n;
int m;
// Lazy segment tree over source gates 0..m-1:
//   sum0[node] / sum1[node] - sum of contr[] over the range's sources currently in state 0 / 1
//   flip[node]              - pending "toggle the whole range" marker (0 or 1)
long long sum0[1000005];
long long sum1[1000005];
long long flip[1000005];
// Builds the segment tree for sources [st, dr] rooted at `node`. A leaf puts the
// source's weight contr[st + n] into sum1 if the source is on, otherwise into sum0.
// Internal sums are kept un-reduced (they stay well inside long long).
void build(int node, int st, int dr)
{
    if (st != dr)
    {
        const int mij = (st + dr) / 2;
        const int left = node * 2;
        const int right = left + 1;
        build(left, st, mij);
        build(right, mij + 1, dr);
        sum0[node] = sum0[left] + sum0[right];
        sum1[node] = sum1[left] + sum1[right];
        return;
    }
    const long long weight = contr[st + n];
    const bool on = state[st] != 0;
    sum0[node] = on ? 0 : weight;
    sum1[node] = on ? weight : 0;
}
// Sets up the circuit: builds the children lists from P (P[0] is never read),
// copies the initial source states from A, computes modes[] and contr[],
// then builds the segment tree over the m source gates.
void init(int N, int M, std::vector<int> P, std::vector<int> A)
{
    n = N;
    m = M;
    const int total_gates = n + m;
    for (int gate = 1; gate < total_gates; ++gate)
    {
        adj[P[gate]].push_back(gate);
    }
    for (int src = 0; src < m; ++src)
    {
        state[src] = A[src];
    }
    dfs(0);
    // The root's weight is the multiplicative identity.
    contr[0] = 1;
    dfs2(0);
    build(1, 0, m - 1);
}
// Applies a pending toggle at `node`: swaps its on/off sums and, unless the
// node is a leaf, hands the toggle down to both children before clearing it.
void push(int node, int st, int dr)
{
    if (flip[node] != 1)
    {
        return;
    }
    swap(sum0[node], sum1[node]);
    if (st != dr)
    {
        flip[2 * node] ^= 1;
        flip[2 * node + 1] ^= 1;
    }
    flip[node] = 0;
}
// Toggles every source in [qst, qdr] within the subtree `node` covering [st, dr],
// then recomputes this node's sums from its children.
void update(int node, int st, int dr, int qst, int qdr)
{
    push(node, st, dr);
    const bool nothing_to_do = st > dr || st > qdr || qst > dr || qst > qdr;
    if (nothing_to_do)
    {
        return;
    }
    const bool fully_covered = qst <= st && dr <= qdr;
    if (fully_covered)
    {
        // Mark and resolve immediately so this node's sums are current for the parent.
        flip[node] = 1;
        push(node, st, dr);
        return;
    }
    const int mij = (st + dr) / 2;
    const int left = node * 2;
    const int right = left + 1;
    update(left, st, mij, qst, qdr);
    update(right, mij + 1, dr, qst, qdr);
    sum0[node] = sum0[left] + sum0[right];
    sum1[node] = sum1[left] + sum1[right];
}
// Toggles source gates L..R (gate ids, i.e. n..n+m-1) and returns the weighted
// count of "on" sources, sum1 at the root, reduced modulo `mod`.
int count_ways(int L, int R)
{
    // The segment tree is indexed by source position 0..m-1, not by gate id.
    update(1, 0, m - 1, L - n, R - n);
    // Tree sums are un-reduced (each leaf < mod, so the total < m * mod, which fits
    // in long long); one reduction here suffices. The result is < mod, so it fits in int.
    return static_cast<int>(sum1[1] % mod);
}