#include "circuit.h"
/*
* Author: Nonoze
* Created: Tuesday 03/03/2026
*/
#include <bits/stdc++.h>
using namespace std;
// Debug printing: unless DEBUG is defined, dbg(...) expands to nothing.
// NOTE(review): the DEBUG-mode definition of dbg is presumably supplied elsewhere (e.g. a local debug header) — confirm.
#ifndef DEBUG
#define dbg(...)
#endif
// Uncomment to send all `cout` output to stderr with an "OUT: " prefix.
// #define cout cerr << "OUT: "
// Replaces every `endl` token with '\n' so output is not flushed on each line.
#define endl '\n'
// Newline followed by an explicit flush.
#define endlfl '\n' << flush
// Prints x and a newline, then returns from the enclosing void function.
// x is spliced in unparenthesized on purpose-compatible terms: quit(a << ' ' << b) streams both values.
#define quit(x) return (void)(cout << x << endl)
// Input helpers. Every single-value overload is declared up front so that nested
// types (e.g. a pair holding a vector) resolve to the most specific overload.
template<typename T> void read(T& x);
template<typename T1, typename T2> void read(std::pair<T1, T2>& p);
template<typename T> void read(std::vector<T>& v);

// Reads one value with operator>>.
template<typename T> void read(T& x) { std::cin >> x; }
// Reads both members of a pair: first, then second.
template<typename T1, typename T2> void read(std::pair<T1, T2>& p) { read(p.first); read(p.second); }
// Reads every element of an already-sized vector, in order.
template<typename T> void read(std::vector<T>& v) { for (auto& x : v) read(x); }
// Reads any number (>= 2) of arguments, left to right.
template<typename T1, typename T2, typename... Ts>
void read(T1& x, T2& y, Ts&... rest) {
    read(x);
    read(y);
    (read(rest), ...);
}
// Prints the elements of v, each followed by a space, then a newline.
template<typename T> void print(const std::vector<T>& v) {
    for (const auto& x : v) std::cout << x << ' ';
    std::cout << '\n';
}
// Container shorthands. Macro arguments are parenthesized so that expressions
// such as sz(*ptr) or make_unique(*ptr) expand as intended.
#define sz(x) static_cast<int>((x).size())
#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
// Sorts v and erases duplicate elements (v ends up sorted and duplicate-free).
#define make_unique(v) sort(all(v)), (v).erase(unique(all(v)), (v).end())
#define pb push_back
#define mp(a, b) make_pair(a, b)
#define fi first
#define se second
// In-place minimum / maximum: a becomes min(a, b) / max(a, b).
#define cmin(a, b) (a) = min((a), (b))
#define cmax(a, b) (a) = max((a), (b))
// Answer printers; the Q* forms also return from the enclosing void function.
#define YES cout << "YES" << endl
#define NO cout << "NO" << endl
#define QYES quit("YES")
#define QNO quit("NO")
// Modulus of the task. 1000002022 is even (2 * 500001011), so it is NOT prime:
// only values coprime to it are invertible, and Fermat's little theorem
// (a^(MOD-2)) cannot be used to compute inverses.
const int MOD = 1000002022, LOG = 20;

// Integer modulo the value referenced by `md`.
// T is the storage type; its unsigned counterpart holds the residue, always in [0, md).
// Products are formed in V, which is wide enough for the product of two residues.
template<typename T, auto &md>
struct modint {
#define tpU template<typename U>
    using V = std::conditional_t<sizeof(T) <= 4, uint64_t, __uint128_t>;
    // Signed type for the Bezout coefficients computed by inv().
    using S = std::conditional_t<sizeof(T) <= 4, int64_t, __int128>;
    std::make_unsigned_t<T> x;

    modint() : x(0) {}
    // Any integral value, negative ones included, is reduced into [0, md).
    tpU modint(U y) : x(normalize(y)) {}
    operator T() const { return x; }

    modint operator-() const { return modint() - *this; }
    tpU modint operator+(U y) const { return modint(*this) += y; }
    tpU modint operator-(U y) const { return modint(*this) -= y; }
    tpU modint operator*(U y) const { return modint(*this) *= y; }
    tpU modint operator/(U y) const { return modint(*this) /= y; }
    modint& operator++() { if (++x == md) x = 0; return *this; }
    modint& operator--() { if (x-- == 0) x = md - 1; return *this; }
    // Postfix forms follow the usual convention and return the previous value.
    modint operator++(int) { modint old = *this; ++*this; return old; }
    modint operator--(int) { modint old = *this; --*this; return old; }
    modint& operator+=(modint y) { if ((x += y.x) >= md) x -= md; return *this; }
    modint& operator-=(modint y) { if ((x += md - y.x) >= md) x -= md; return *this; }
    modint& operator*=(modint y) { x = x * static_cast<V>(y.x) % md; return *this; }
    modint& operator/=(modint y) { return *this *= y.inv(); }

    // this^y by binary exponentiation; a negative y yields inv(this)^|y|.
    tpU modint pow(U y) const {
        if (y < 0) return pow(-y).inv();
        modint res(1), t(*this);
        for (; y; y >>= 1, t *= t) if (y & 1) res *= t;
        return res;
    }
    // Modular inverse via the extended Euclidean algorithm. Unlike Fermat, this
    // is correct for any modulus as long as gcd(x, md) == 1. When no inverse
    // exists the result is meaningless (it is 0 for x == 0).
    modint inv() const {
        S a = x, b = md, u = 1, v = 0;
        while (b != 0) {
            S q = a / b;
            a -= q * b;
            std::swap(a, b);
            u -= q * v;
            std::swap(u, v);
        }
        return modint(u);
    }

    // Grows the cached factorial / inverse-factorial tables to cover 0..n.
    // NOTE: ifact_[i] is only meaningful when every k <= i is invertible mod md;
    // with an even md (such as MOD above) ifact_[2] onward is not.
    static void precompute_fact(T n) {
        if (n < fact_.size()) return;
        T i = fact_.size();
        fact_.resize(n + 1);
        ifact_.resize(n + 1);
        for (; i <= n; i++) {
            fact_[i] = V(fact_[i - 1]) * i % md;
            ifact_[i] = V(ifact_[i - 1]) * inv(i).x % md;
        }
    }
    // Binomial coefficient C(n, k); 0 outside 0 <= k <= n.
    static modint combi(T k, T n) {
        if (n < 0 || k < 0 || k > n) return 0;
        return fact(n) * ifact(k) * ifact(n - k);
    }
    static modint inv(T x) { return modint(x).inv(); }
    tpU static modint pow(U x, U y) { return modint(x).pow(y); }
    static modint fact(T x) { precompute_fact(x); return fact_[x]; }
    static modint ifact(T x) { precompute_fact(x); return ifact_[x]; }
    static inline std::vector<T> fact_ = {1, 1}, ifact_ = {1, 1};

private:
    // Reduces an integral y into [0, md) without storing a negative value in
    // the unsigned residue.
    tpU static std::make_unsigned_t<T> normalize(U y) {
        if (y >= 0 && y < md) return y;
        y %= md;
        return y < 0 ? y + md : y;
    }
};
using mint = modint<uint32_t, MOD>;
// Lazy-propagation segment tree over positions [0, n).
// T is the node value and U the pending-update tag. The combine / apply /
// combinelazy hooks at the bottom are specialised for this solution: T is a
// pair of sums that an update swaps, and U is a flip flag merged with XOR.
template<class T, class U>
struct segtree {
    std::vector<T> st;    // node values; children of v are 2v+1 and 2v+2
    std::vector<U> lazy;  // pending updates not yet applied to st[v]
    int n = 0;            // number of leaves; 0 means an empty tree
    T idElement{};        // identity for combine, returned for empty ranges
    U idUpdate{};         // tag meaning "no pending update"

    // (Re)initialises the tree for _n >= 0 leaves. assign() is used rather than
    // resize() so that re-creating an existing tree discards its old contents.
    void create(int _n, T idEl, U idUp) {
        n = _n;
        idElement = idEl;
        idUpdate = idUp;
        st.assign(4 * static_cast<std::size_t>(n), idElement);
        lazy.assign(4 * static_cast<std::size_t>(n), idUpdate);
    }
    segtree() {}
    segtree(int _n, T idEl, U idUp) {
        create(_n, idEl, idUp);
    }
    segtree(int _n, T idEl, U idUp, const std::vector<T> &creation) {
        assert(static_cast<int>(creation.size()) == _n);
        create(_n, idEl, idUp);
        build(creation);
    }
    // Fills node v (covering [l, r]) from creation[l..r].
    void build(int v, int l, int r, const std::vector<T> &creation) {
        if (l == r) {
            st[v] = creation[l];
            return;
        }
        int mid = (l + r) / 2;
        build(v * 2 + 1, l, mid, creation);
        build(v * 2 + 2, mid + 1, r, creation);
        st[v] = combine(st[2 * v + 1], st[2 * v + 2]);
    }
    // Applies v's pending update to st[v] and hands it down to the children.
    void propagate(int v, int l, int r) {
        if (lazy[v] == idUpdate) return;
        st[v] = apply(st[v], lazy[v], l, r);
        if (l != r) {
            int mid = (l + r) / 2;
            lazy[2 * v + 1] = combinelazy(lazy[2 * v + 1], lazy[v], l, mid);
            lazy[2 * v + 2] = combinelazy(lazy[2 * v + 2], lazy[v], mid + 1, r);
        }
        lazy[v] = idUpdate;
    }
    T query(int v, int l, int r, int ql, int qr) {
        propagate(v, l, r);
        if (l > qr || r < ql) return idElement;
        if (l >= ql && r <= qr) return st[v];
        int mid = (l + r) / 2;
        T s1 = query(v * 2 + 1, l, mid, ql, qr);
        T s2 = query(v * 2 + 2, mid + 1, r, ql, qr);
        return combine(s1, s2);
    }
    void update(int v, int l, int r, int ql, int qr, U upd) {
        propagate(v, l, r);
        if (l > qr || r < ql) return;
        if (l >= ql && r <= qr) {
            // lazy[v] was just cleared by propagate, so it can be overwritten.
            lazy[v] = upd;
            propagate(v, l, r);
            return;
        }
        int mid = (l + r) / 2;
        update(v * 2 + 1, l, mid, ql, qr, upd);
        update(v * 2 + 2, mid + 1, r, ql, qr, upd);
        st[v] = combine(st[v * 2 + 1], st[v * 2 + 2]);
    }
    // Public entry points. An empty tree (n <= 0) is a no-op / returns the
    // identity instead of recursing on the invalid range [0, -1].
    void build(const std::vector<T> &creation) {
        if (n <= 0) return;
        assert(static_cast<int>(creation.size()) >= n);
        build(0, 0, n - 1, creation);
    }
    T query(int l, int r) {
        if (n <= 0) return idElement;
        return query(0, 0, n - 1, l, r);
    }
    void update(int l, int r, U upd) {
        if (n <= 0) return;
        update(0, 0, n - 1, l, r, upd);
    }
    void update(int point, U upd) {
        update(point, point, upd);
    }
    // Merges two child (or partial query) results: component-wise sum.
    T combine(T l, T r) {
        return T(l.first + r.first, l.second + r.second);
    }
    // Applies a pending update to a node by swapping the two components. The tag
    // value is not inspected; this equals one flip because a non-identity tag
    // only ever means "flip" here (this file always passes 1).
    T apply(T curr, U /*upd*/, int /*l*/, int /*r*/) {
        return T(curr.second, curr.first);
    }
    // Composes two pending flips with XOR, so two flips cancel out.
    U combinelazy(U old_upd, U new_upd, int /*l*/, int /*r*/) {
        return old_upd ^ new_upd;
    }
};
// Children lists: adj[v] holds every node i whose parent p[i] is v.
// NOTE(review): the fixed capacity 200005 presumably covers N + M for the task limits — confirm.
vector<int> adj[200005];
vector<int> a;        // initial flag of each leaf (source) node, copied in init
vector<int> p;        // parent of each node; p[0] is never read
vector<mint> nb;      // per-leaf weight, filled by dfs2
vector<mint> poss;    // per-node subtree product, filled by dfs
int n;                // number of internal nodes; leaves are numbered from n
int m;                // number of leaf nodes
int root = 0;
segtree<pair<mint, mint>, int> st;  // (sum of weights with flag set, sum with flag clear)
// Bottom-up pass over the subtree rooted at u. A node without children gets 1;
// any other node gets (number of children) times the product of its children's
// values. The value is cached in poss[u] and returned. Recursion depth equals
// the height of the tree below u.
// NOTE(review): presumably this counts the parameter choices of all threshold
// gates in the subtree — confirm against the task statement.
mint dfs(int u) {
    const vector<int> &kids = adj[u];
    mint ways = kids.empty() ? mint(1) : mint(sz(kids));
    for (int child : kids) ways *= dfs(child);
    poss[u] = ways;
    return ways;
}
void dfs2(int u, mint act=1) {
if (adj[u].empty()) {
nb[u-n]=act;
return;
}
vector<mint> pref(sz(adj[u]), 1), suff(sz(adj[u]), 1); pref[0]=poss[adj[u][0]], suff.back()=poss[adj[u].back()];
for (int i=1; i<sz(adj[u]); i++) {
pref[i]=pref[i-1]*poss[adj[u][i]];
}
for (int i=sz(adj[u])-2; i>=0; i--) {
suff[i]=suff[i+1]*poss[adj[u][i]];
}
for (int i=0; i<sz(adj[u]); i++) {
mint nact=act;
if (i) nact*=pref[i-1];
if (i!=sz(adj[u])-1) nact*=suff[i+1];
dfs2(adj[u][i], nact);
}
}
// Grader entry point. It stores the tree, computes each leaf's weight, and loads
// the weights into the segment tree.
// N: number of internal nodes (ids 0..N-1). M: number of leaves (ids N..N+M-1,
// which is the indexing dfs2 relies on via nb[u - n]).
// P: parent of each node (P[0] is not read). aa: initial flag of each leaf.
void init(int N, int M, vector<int> P, vector<int> aa) {
    // P and aa are owned by-value copies, so move them instead of copying again.
    a = std::move(aa);
    p = std::move(P);
    n = N, m = M;
    for (int i = 1; i < n + m; i++) adj[p[i]].push_back(i);
    poss.resize(n + m);
    dfs(root);
    nb.resize(m);
    dfs2(root);
    st.create(m, {0, 0}, 0);
    vector<pair<mint, mint>> creation(m);
    for (int i = 0; i < m; i++) {
        // A set flag puts the leaf's weight in .first; otherwise it goes in .second.
        if (a[i]) creation[i] = {nb[i], 0};
        else creation[i] = {0, nb[i]};
    }
    st.build(creation);
}
int count_ways(int L, int R) {
st.update(L-n, R-n, 1);
return st.query(0, m-1).fi;
}