#include "circuit.h"
#include <cassert>
#include <cstdio>
#include <vector>
#include <iostream>
using namespace std;
// Capacity of every per-gate table below (gate ids 0 .. kMaxGates-1).
const int kMaxGates = 200005;
// Modulus applied to every product and sum in this file.
long long mod = 1000002022;
// adj[v]: ids of the gates whose parent is v (filled by init from P).
std::vector<int> adj[kMaxGates];
// modes[v]: 1 for a gate with no children; otherwise the child count times the
// product of modes over its children, reduced modulo `mod` (filled by dfs).
long long modes[kMaxGates] = {};
// contr[v]: contr[parent] times the product of modes over v's siblings,
// modulo `mod`; contr[0] is seeded with 1 by init (filled by dfs2).
long long contr[kMaxGates] = {};
// Computes modes[v] for every gate v in the subtree rooted at `curr`:
//   - a gate with no children gets 1;
//   - any other gate gets (number of children) * product of modes[child],
//     reduced modulo `mod`.
// Implemented without recursion: the child lists can form a chain as deep as
// the number of gates, which can exhaust the call stack when each level is a
// recursive call.
void dfs(int curr)
{
    // Pre-order listing: every gate appears before all of its descendants.
    vector<int> order;
    vector<int> pending(1, curr);
    while (!pending.empty())
    {
        const int v = pending.back();
        pending.pop_back();
        order.push_back(v);
        for (int child : adj[v])
            pending.push_back(child);
    }

    // Walking the listing backwards visits every child before its parent,
    // so each modes[child] is final by the time the parent reads it.
    for (auto it = order.rbegin(); it != order.rend(); ++it)
    {
        const int v = *it;
        if (adj[v].empty())
        {
            modes[v] = 1;
            continue;
        }
        long long ways = static_cast<long long>(adj[v].size());
        for (int child : adj[v])
            ways = ways * modes[child] % mod;
        modes[v] = ways;
    }
}
void dfs2(int curr)
{
if(adj[curr].size()==0)
{
return;
}
vector<long long>prefix;
vector<long long>suffix;
vector<long long>vec;
for(auto k:adj[curr])
{
vec.push_back(modes[k]);
}
prefix=vec;
suffix=vec;
for(int i=1;i<vec.size();i++)
{
prefix[i]=(prefix[i-1]*prefix[i])%mod;
}
for(int i=vec.size()-2;i>=0;i--)
{
suffix[i]=(suffix[i+1]*suffix[i])%mod;
}
for(int i=0;i<adj[curr].size();i++)
{
int k=adj[curr][i];
long long a,b;
if(i!=0)
{
a=prefix[i-1];
}
else
{
a=1;
}
if(i!=adj[curr].size()-1)
{
b=suffix[i+1];
}
else
{
b=1;
}
contr[k]=((a*b)%mod*contr[curr])%mod;
dfs2(k);
}
}
// state[i]: current value of the gate with id n + i (seeded from A by init,
// XOR-toggled by count_ways).
int state[200005] = {};
// n: id offset of the first gate tracked in `state`; m: how many such gates.
// Both are set by init.
int n{}, m{};
// One-time setup.
//   N, M : gate counts; gates n .. n+M-1 are the ones whose values live in `state`.
//   P    : parent id of each gate (entry 0 is ignored: gate 0 is the root).
//   A    : initial values for state[0 .. M-1].
// Builds the child lists, then fills modes (dfs) and contr (dfs2) from root 0.
void init(int N, int M, std::vector<int> P, std::vector<int> A)
{
    n = N;
    m = M;

    const int gateCount = n + m;
    for (int gate = 1; gate < gateCount; ++gate)
        adj[P[gate]].push_back(gate);

    for (int leaf = 0; leaf < m; ++leaf)
        state[leaf] = A[leaf];

    dfs(0);
    // The root's multiplier must be seeded before propagating down the tree.
    contr[0] = 1;
    dfs2(0);
}
int count_ways(int L, int R)
{
long long total=0;
for(int i=L-n;i<=R-n;i++)
{
state[i]^=1;
}
for(int i=0;i<m;i++)
{
total=(total+state[i]*contr[i+n])%mod;
}
return total%mod;
}