Submission #1284748

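| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1284748 | | PlayVoltz | Digital Circuit (IOI22_circuit) | C++20 | 2 / 100 | 303 ms | 12212 KiB |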
#include "circuit.h"
#include <bits/stdc++.h>
using namespace std;

// Every `int` in this file is 64-bit, so the product of two values below MOD
// fits without overflow; the grader-facing signatures use `signed` for 32-bit ints.
#define int long long
// Competitive-programming shorthands (only `pb` is used in the visible code).
#define pii pair<int,int>
#define mp make_pair
#define pb push_back
#define fi first
#define se second

// Modulus applied to every count, sum and product in this file.
const int MOD = 1000002022;

// Set by init. Gates [0, n) are internal (dfs/dfs2 only descend into them);
// gates [n, n+m) are the sources, numbered 0..m-1 inside the segment tree.
int n, m;

// vect[v]:  weight of gate v computed by dfs2; the segment tree stores
//           vect[n+s] for source s and sums it over sources in state 1.
// state[i]: initial state (0 or 1) of source gate n+i, copied from A in init;
//           only read while building the segment tree.
// all[v]:   for v < n, (number of children) * product of all[child] mod MOD,
//           computed by dfs; stays 1 for source gates.
vector<int> vect, state, all;
// graph[v]: the gates i with p[i] == v (built in init).
vector<vector<int> > graph;

// Modular inverse of `num` modulo MOD via the extended Euclidean algorithm.
// MOD (1000002022) is even, hence not prime, so an inverse exists only when
// gcd(num, MOD) == 1. Returns the inverse in [0, MOD), or 0 when `num` has no
// inverse (e.g. any even value, or 0) instead of dividing by zero.
// NOTE(review): not referenced by the visible code.
int inv(int num){
	// Reduce into [0, MOD) so negative inputs and inputs >= MOD behave alike.
	num %= MOD;
	if (num < 0) num += MOD;
	// Invariant: a == x*num (mod MOD) and b == y*num (mod MOD).
	int a = num, b = MOD, x = 1, y = 0;
	while (b != 0){
		int q = a / b, t;
		t = a - q * b, a = b, b = t;
		t = x - q * y, x = y, y = t;
	}
	// a is now gcd(num, MOD); only gcd 1 yields an inverse.
	if (a != 1) return 0;
	x %= MOD;
	if (x < 0) x += MOD;
	return x;
}

// Combines two segment-tree summaries {sum for state 0, sum for state 1}
// component-wise modulo MOD. Arguments are taken by const reference because
// this runs on every build/update step; by-value parameters copied two heap
// vectors per call.
vector<int> merge(const vector<int> &a, const vector<int> &b){
	return {(a[0] + b[0]) % MOD, (a[1] + b[1]) % MOD};
}

struct node{
	int s, e, m, lazy;
	vector<int> val;
	node *l, *r;
	
	void propagate(){
		if (lazy)swap(val[0], val[1]);
		if (s!=e){
			l->lazy^=lazy;
			r->lazy^=lazy;
		}
		lazy=0;
	}
	node(int S, int E){
		s=S, e=E, m=(s+e)/2, lazy=0;
		val.resize(2, 0);
		if (s==e)val[state[s]]=vect[n+s];
		else{
			l = new node(s, m);
			r = new node(m+1, e);
			val=merge(l->val, r->val);
		}
	}
	void up(int left, int right){
		propagate();
		if (s==left && e==right)lazy^=1;
		else{
			if (right<=m)l->up(left, right);
			else if (left>m)r->up(left, right);
			else l->up(left, m), r->up(m+1, right);
			r->propagate(), l->propagate();
			val=merge(l->val, r->val);
		}
	}
	int query(int left, int right){
		propagate();
		if (s==left && e==right)return val[1];
		if (right<=m)return l->query(left, right);
		else if (left>m)return r->query(left, right);
		else return (l->query(left, m)+r->query(m+1, right))%MOD;
	}
}*st;

// Post-order pass: for every gate v < n in the subtree of `node`, sets
// all[v] = (number of children of v) * product of all[child], modulo MOD.
// Source gates (>= n) are left untouched. `p` (the parent) is unused.
void dfs(int node, int p){
	if (node < n){
		int ways = graph[node].size();
		for (int child : graph[node]){
			dfs(child, node);
			ways = ways * all[child] % MOD;
		}
		all[node] = ways;
	}
}

void dfs2(int node, int p){
	if (node>=n)return;
	vector<int> psum(graph[node].size()+1, 1), ssum(graph[node].size()+2, 1);
	for (int i=0; i<graph[node].size(); ++i)psum[i+1]=(psum[i]*all[graph[node][i]])%MOD;
	for (int i=graph[node].size()-1; i>=1; --i)ssum[i]=(ssum[i+1]*all[graph[node][i]])%MOD;
	for (int i=0; i<graph[node].size(); ++i)vect[graph[node][i]]=(psum[i]*ssum[i+2])%MOD;
}

// Grader entry point; must run before count_ways (which dereferences st).
// Records the sizes and initial source states, builds child lists from p,
// computes all[] (dfs) and the per-gate weights vect[] (dfs2), then builds the
// segment tree over the m sources.
void init(signed N, signed M, vector<signed> p, vector<signed> A) {
	n = N;
	m = M;
	state.assign(A.begin(), A.begin() + m);

	const int gates = n + m;
	graph.resize(gates);
	vect.resize(gates, 1);
	all.resize(gates, 1);
	for (int gate = 1; gate < gates; ++gate)
		graph[p[gate]].push_back(gate);

	dfs(0, 0);
	dfs2(0, 0);
	st = new node(0, m - 1);
}

signed count_ways(signed l, signed r){
	st->up(l-n, r-n);
	return st->query(0, m-1);
}
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...