Submission #629743

# | Time | Username | Problem | Language | Result | Execution time | Memory
629743 | | MatheusLealV | Digital Circuit (IOI22_circuit) | C++17 | 100 / 100 | 1073 ms | 27356 KiB
#include "circuit.h"
#include <bits/stdc++.h>
// Capacity of every per-gate array. Presumably sized for the task's limit of
// N + M <= 200000 gates plus slack — TODO confirm if the limits change.
constexpr int N = 200020;
#define pb push_back
#define all(x) begin(x), end(x)
#define sz(x) (int)(x).size()
using namespace std;
using ll = long long;
// Every count is reported modulo this value.
constexpr ll mod = 1000002022;

// Per-gate arrays, indexed by gate id:
// 0..n-1 are threshold gates, n..n+m-1 are source gates.
ll yes[N];            // per-gate scratch counts
ll no[N];             // per-gate scratch counts
ll state[N];          // state[x]: initial state of source gate x (set by init)
ll pai[N];            // pai[x]: parent gate of x (pai[0] is never set)
ll n, m;              // number of threshold gates / source gates
ll save[N];           // filled with -1 by init()
ll sum[N];            // sum[x]: number of parameter assignments of the
                      //   threshold gates in x's subtree (1 for a source gate)
vector<ll> grafo[N];  // grafo[x]: children (inputs) of gate x

ll pref[N];           // pref[x]: product of sum[] over every subtree hanging
                      //   off the root-to-x path (siblings of x and of its ancestors)

void solve(){
	for(ll x=n;x<n+m+1;x++)yes[x]=state[x],no[x]=1-state[x],save[x]=sum[x]=1;
	pref[0]=1;
	for(ll x = n-1; x >= 0; x--){
		ll sum_cost = 1, C = sz(grafo[x]);
		no[x]=yes[x]=0;
		for(auto v: grafo[x]){
			ll old_yes = yes[x];

 			yes[x] = (1LL*old_yes*sum[v])%mod + (1LL*sum_cost*yes[v])%mod;
 			yes[x] %= mod;
 			pref[v] = sum_cost;
			sum_cost = (1LL*sum_cost*sum[v])%mod;
		}
		reverse(all(grafo[x]));
		ll suf = 1;
		for(auto v: grafo[x]){
			pref[v] = (1LL*pref[v]*suf)%mod;
			suf = (1LL*suf*sum[v])%mod;
		}

		if(save[x] == -1) save[x] = sum_cost;
		assert(save[x] == sum_cost);
		no[x] = (1LL*C*sum_cost%mod - yes[x])%mod;
		no[x] = (no[x] + mod)%mod;
		sum[x] = (yes[x] + no[x])%mod;
	}

	for(ll i = 0; i < n+m; i++){
		if(i)pref[i] = (pref[i] * pref[pai[i]])%mod;
	}
}
// Segment tree over source-gate positions 0..m-1 (heap layout, root = 1).
// lazy[node]:   1 while a flip of node's range has not been applied to cnt[node].
ll lazy[4*N];
// cnt[node][t]: sum of pref[] weights of the leaves in node's range whose
//               current state is t (0 or 1).
ll cnt[4*N][2];
// custo[k]:     initial state A[k] of source gate n+k (leaf position k).
ll custo[N];
// Midpoint of the current range [a, b]; expands using locals named a and b.
#define mid ((a+b)/2)
void build(int nod, int a, int b){
	if(a==b){
		cnt[nod][custo[a]] = pref[a+n];
		return;
	}
	build(2*nod,a,mid); build(2*nod+1,mid+1,b);
	cnt[nod][0] = (cnt[2*nod][0]+cnt[2*nod+1][0])%mod;
	cnt[nod][1] = (cnt[2*nod][1]+cnt[2*nod+1][1])%mod;
}
// Applies the pending flip (if any) of `nod`, which covers positions [a, b]:
// swaps its state-0 / state-1 weight sums and, for an internal node, defers
// the flip to both children by toggling their lazy flags.
void prop(int nod, int a, int b){
	if(lazy[nod]){
		lazy[nod] = 0;
		swap(cnt[nod][0], cnt[nod][1]);
		const bool internal = (a != b);
		if(internal){
			for(int c = 2*nod; c <= 2*nod + 1; c++) lazy[c] ^= 1;
		}
	}
}
void upd(int nod, int a, int b, int i, int j){
	prop(nod,a,b);
	if(j < a or i > b) return;
	if(i <= a and j >= b){
		lazy[nod] ^= 1;
		prop(nod,a,b);
		return;
	}
	upd(2*nod,a,mid,i,j); upd(2*nod+1,mid+1,b,i,j);
	cnt[nod][0] = (cnt[2*nod][0] + cnt[2*nod+1][0])%mod;
	cnt[nod][1] = (cnt[2*nod][1] + cnt[2*nod+1][1])%mod;
}
 
void init(int n_, int m_, vector<int> P, vector<int> A) {
	memset(save,-1,sizeof save);
	n = n_, m = m_;
	for(int i = 1; i < n+m; i++){
		grafo[P[i]].pb(i);
		pai[i] = P[i];
	}
	for(int i = 0;i<m;i++) state[i+n] = A[i],custo[i]=A[i];
	solve();

	build(1,0,m-1);
}
 
// Flips the states of source gates L..R (gate ids, i.e. leaf positions
// L-n..R-n) and returns cnt[1][1]. That value is the sum of pref[] weights
// over the source gates now in state 1, already reduced modulo 1000002022.
int count_ways(int L, int R) {
	const int lo = L - n;
	const int hi = R - n;
	upd(1, 0, m-1, lo, hi);
	prop(1, 0, m-1);
	const auto ways = cnt[1][1];
	// ways < 1000002022 < INT_MAX, so narrowing to int is lossless.
	return static_cast<int>(ways);
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...