// Submission #1234028
//
// # | Time | Username | Problem | Language | Result | Execution time | Memory
// 1234028 | - | trimkus | Digital Circuit (IOI22_circuit) | C++20 | 0 / 100 | 394 ms | 8980 KiB
#include "circuit.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Modulus applied to every count computed in this file.
const ll MOD = 1000002022;

// Capacity of every per-gate array below. Gate ids run over 0..N+M-1 and
// A / tin / tout are indexed by all of them (init writes A[N..N+M-1] and the
// Euler tour stamps every node), so the capacity has to cover N + M rather
// than N alone. With N and M each up to 1e5 (task limits) that is 2e5 ids;
// the previous 1e5 + 5 overflowed these arrays on large inputs.
const int MAXN = 200005;

// A[v]: current 0/1 state of leaf gate v (only ids >= N are read/written).
int A[MAXN];
// adj[v]: children of gate v, in increasing id order.
vector<int> adj[MAXN];
// Euler-tour entry/exit stamps; u is an ancestor-or-self of v iff
// tin[u] <= tin[v] && tout[v] <= tout[u].
int tin[MAXN], tout[MAXN];
// Number of inner gates (ids 0..N-1) and of leaf gates (ids N..N+M-1).
int N, M;
// arr[v][i]: cached (ways to be 0, ways to be 1) of the i-th child of inner
// gate v, as last returned by calc_init / calc.
vector<pair<ll, ll>> arr[MAXN];
/// Recomputes the whole subtree of `node` from scratch and refreshes the
/// cached child results in `arr` along the way.
///
/// @param node  gate id; ids >= N are leaves whose state is read from A.
/// @return {ways, ways} modulo MOD: `.first` counts the configurations of the
///         subtree that leave `node` at 0, `.second` those that leave it at 1.
///
/// For an inner gate with m children the original formulation was
///   second = sum_k k * (#ways exactly k children are 1),
///   first  = m * prod_i (first_i + second_i) - second.
/// The k-weighted sum is the derivative of prod_i (first_i + second_i * x)
/// at x = 1, i.e.
///   second = sum_i second_i * prod_{j != i} (first_j + second_j),
/// which is evaluated here with prefix/suffix products in O(m) time and
/// memory. The former (m+1) x (m+1) table was quadratic in the fan-in and
/// could not be allocated for a gate with ~1e5 children. MOD is not prime,
/// so the "all others" product is built from prefix/suffix parts rather than
/// by modular division.
pair<ll, ll> calc_init(int node) {
  if (node >= N) {
    // Leaf: exactly one configuration, landing on its current state.
    return {A[node] ^ 1, A[node] ^ 0};
  }

  const int m = static_cast<int>(adj[node].size());
  vector<pair<ll, ll>>& kids = arr[node];
  kids.resize(m);
  for (int i = 0; i < m; ++i) {
    kids[i] = calc_init(adj[node][i]);
  }

  // suffix[i] = prod of (first + second) over children i..m-1, mod MOD.
  vector<ll> suffix(m + 1, 1);
  for (int i = m - 1; i >= 0; --i) {
    const ll total = (kids[i].first + kids[i].second) % MOD;
    suffix[i] = suffix[i + 1] * total % MOD;
  }

  ll res = 0;
  ll prefix = 1;  // prod of (first + second) over children 0..i-1, mod MOD
  for (int i = 0; i < m; ++i) {
    const ll others = prefix * suffix[i + 1] % MOD;
    res = (res + kids[i].second * others) % MOD;
    prefix = prefix * ((kids[i].first + kids[i].second) % MOD) % MOD;
  }

  // All configurations of this gate: m choices times every child combination.
  ll tot = static_cast<ll>(m) * suffix[0] % MOD;
  tot = (tot - res + MOD) % MOD;
  return {tot, res};
}

/// One-time setup: stores the sizes and the initial leaf states, builds the
/// child lists from the parent array, stamps every node with Euler-tour
/// entry/exit times, and fills the per-gate caches via calc_init(0).
///
/// @param _N  number of inner gates (ids 0.._N-1)
/// @param _M  number of leaf gates (ids _N.._N+_M-1)
/// @param P   P[v] is the parent of gate v for v >= 1; P[0] is not read
/// @param _A  initial states of the leaves, _A[k] belonging to gate _N + k
void init(int _N, int _M, std::vector<int> P, std::vector<int> _A) {
  N = _N;
  M = _M;
  const int total = N + M;

  for (int k = 0; k < M; ++k) {
    A[N + k] = _A[k];
  }

  // Children end up in increasing id order because v is scanned upwards.
  for (int v = 1; v < total; ++v) {
    adj[P[v]].push_back(v);
  }

  // Depth-first tour from gate 0 with an explicit stack. Each frame holds a
  // node and the index of the next child to descend into; a node receives
  // tin when it is pushed and tout once all its children are finished.
  int clock = 0;
  std::vector<std::pair<int, int>> frames;
  tin[0] = clock++;
  frames.emplace_back(0, 0);
  while (!frames.empty()) {
    const int v = frames.back().first;
    const int k = frames.back().second;
    if (k < static_cast<int>(adj[v].size())) {
      ++frames.back().second;  // advance before the push invalidates back()
      const int child = adj[v][k];
      tin[child] = clock++;
      frames.emplace_back(child, 0);
    } else {
      tout[v] = clock++;
      frames.pop_back();
    }
  }

  // Return value is not needed here; the call populates arr[] for calc().
  calc_init(0);
}

/// Refreshes the cached results along the path from `node` down to the one
/// descendant whose Euler interval is [TIN, TOUT], and returns the updated
/// pair for `node`. Children that are not on that path keep whatever value
/// is already stored in arr[node].
///
/// Precondition: arr[node] was sized and filled by an earlier calc_init.
///
/// @param node  gate id; ids >= N are leaves whose state is read from A.
/// @param TIN   tin[] of the target descendant
/// @param TOUT  tout[] of the target descendant
/// @return {ways node is 0, ways node is 1}, both modulo MOD.
///
/// The combination step is the same as in calc_init:
///   second = sum_i second_i * prod_{j != i} (first_j + second_j),
///   first  = m * prod_i (first_i + second_i) - second,
/// evaluated with prefix/suffix products in O(m). This replaces an
/// (m+1) x (m+1) table whose size was quadratic in the fan-in; MOD is not
/// prime, hence no modular division.
pair<ll, ll> calc(int node, int TIN, int TOUT) {
  if (node >= N) {
    return {A[node] ^ 1, A[node] ^ 0};
  }

  const int m = static_cast<int>(adj[node].size());
  vector<pair<ll, ll>>& kids = arr[node];
  for (int i = 0; i < m; ++i) {
    const int child = adj[node][i];
    // Descend only into the child whose interval contains the target.
    if (tin[child] <= TIN && TOUT <= tout[child]) {
      kids[i] = calc(child, TIN, TOUT);
    }
  }

  // suffix[i] = prod of (first + second) over children i..m-1, mod MOD.
  vector<ll> suffix(m + 1, 1);
  for (int i = m - 1; i >= 0; --i) {
    const ll total = (kids[i].first + kids[i].second) % MOD;
    suffix[i] = suffix[i + 1] * total % MOD;
  }

  ll res = 0;
  ll prefix = 1;  // prod of (first + second) over children 0..i-1, mod MOD
  for (int i = 0; i < m; ++i) {
    const ll others = prefix * suffix[i + 1] % MOD;
    res = (res + kids[i].second * others) % MOD;
    prefix = prefix * ((kids[i].first + kids[i].second) % MOD) % MOD;
  }

  ll tot = static_cast<ll>(m) * suffix[0] % MOD;
  tot = (tot - res + MOD) % MOD;
  return {tot, res};
}

int count_ways(int L, int R) {
  for (int i = L; i <= R; ++i) {
    A[i] ^= 1;
  }
  return calc(0, tin[L], tout[L]).second;
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...