Submission #692509

# Submission time Handle Problem Language Result Execution time Memory
692509 2023-02-01T17:05:53 Z CatalinT Digital Circuit (IOI22_circuit) C++17
Compilation error
0 ms 0 KB
#include "circuit.h"

#include <cassert>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

using namespace std;

/// Integer arithmetic modulo the compile-time constant `mod`.
///
/// Values are always kept reduced into [0, mod). `mod` must lie in [1, 2^31]
/// so that adding two reduced values can never overflow 32 bits.
/// `mod` does NOT have to be prime: inv() uses the extended Euclidean
/// algorithm and only requires gcd(value, mod) == 1.
template <std::uint32_t mod>
class modint {
    static_assert(mod >= 1 && mod <= (std::uint32_t(1) << 31),
                  "modint requires 1 <= mod <= 2^31");

private:
    std::uint32_t n;

    /// Reduces any signed 64-bit value into [0, mod) without ever negating it
    /// (negating INT64_MIN would overflow).
    static std::uint32_t reduce(std::int64_t v) {
        std::int64_t r = v % static_cast<std::int64_t>(mod);
        if (r < 0) r += static_cast<std::int64_t>(mod);
        return static_cast<std::uint32_t>(r);
    }

public:
    modint() : n(0) {}
    /// Intentionally implicit so plain integers can be used where a modint is expected.
    modint(std::int64_t n_) : n(reduce(n_)) {}
    static constexpr std::uint32_t get_mod() { return mod; }
    std::uint32_t get() const { return n; }
    bool operator==(const modint& m) const { return n == m.n; }
    bool operator!=(const modint& m) const { return n != m.n; }
    // Both operands are < mod <= 2^31, so the intermediate sums fit in 32 bits.
    modint& operator+=(const modint& m) { n += m.n; n = (n < mod ? n : n - mod); return *this; }
    modint& operator-=(const modint& m) { n += mod - m.n; n = (n < mod ? n : n - mod); return *this; }
    modint& operator*=(const modint& m) { n = std::uint64_t(n) * m.n % mod; return *this; }
    modint operator+(const modint& m) const { return modint(*this) += m; }
    modint operator-(const modint& m) const { return modint(*this) -= m; }
    modint operator*(const modint& m) const { return modint(*this) *= m; }
    /// Division; `m` must be invertible modulo `mod` (see inv()).
    modint operator/(const modint& m) const { return modint(*this) * m.inv(); }

    /// Multiplicative inverse computed with the extended Euclidean algorithm.
    /// Valid for any modulus as long as gcd(get(), mod) == 1 (checked by assert).
    /// Fermat's little theorem (pow(mod - 2)) is NOT used because it is only
    /// correct for prime moduli.
    modint inv() const {
        // Invariant: a == x * n and b == y * n (mod `mod`).
        std::int64_t a = n, b = mod;
        std::int64_t x = 1, y = 0;
        while (b != 0) {
            const std::int64_t q = a / b;
            std::int64_t t = a - q * b; a = b; b = t;
            t = x - q * y; x = y; y = t;
        }
        assert(a == 1 && "modint::inv: value is not invertible modulo mod");
        return modint(x);
    }

    /// Binary exponentiation: returns (*this)^b.
    modint pow(std::uint64_t b) const {
        modint ans = 1, m = modint(*this);
        while (b) {
            if (b & 1) ans *= m;
            m *= m;
            b >>= 1;
        }
        return ans;
    }
};

// All counts in this file are computed modulo 1'000'002'022.
// Note: this modulus is even (composite), so it has no Fermat-style inverses.
using Mint = modint<1'000'002'022>; 

// Vertices 0..N-1 are internal (they own child lists in `tree`);
// vertices N..N+M-1 are leaves (dfs1/dfs2 stop at v >= N). Both are set by init().
int N, M;

// Parent array as passed to init(); P[i] is the parent of vertex i for i >= 1.
vector<int> P;
// A[j] is the current state of leaf vertex N + j; count_ways() flips entries
// with ^= 1 (presumably each entry is 0 or 1 — confirm against the caller).
vector<int> A;

// tree[v] lists the children of internal vertex v, in increasing index order (built in init()).
vector<vector<int>> tree;

// weight[j] is the amount leaf N + j contributes to the answer while A[j] == 1
// (filled by dfs2).
vector<Mint> weight;
// tot_prod[v]: for an internal vertex, its child count times tot_prod of every
// child (i.e. the product of child counts over all internal vertices in its
// subtree); 1 for a leaf. Filled by dfs1.
vector<Mint> tot_prod;

/// Post-order pass that fills tot_prod for every vertex in the subtree of v.
///
/// A leaf (v >= N) gets 1. An internal vertex gets its number of children
/// multiplied by the tot_prod of each child.
/// @return the value stored in tot_prod[v].
Mint dfs1(int v) {
    if (v < N) {
        const std::vector<int>& children = tree[v];
        Mint product = Mint(static_cast<long long>(children.size()));
        for (const int child : children) {
            product *= dfs1(child);
        }
        tot_prod[v] = product;
    } else {
        tot_prod[v] = 1;
    }
    return tot_prod[v];
}

void dfs2(int v, Mint w) {
    if (v >= N) {
        weight[v - N] = w;
        return;
    }

    int n = size(tree[v]);

    assert(n);

    vector<Mint> suf(n);

    suf[n-1] = tot_prod[tree[v].back()];
    for (int i = n - 2; i >= 0; i--) {
        suf[i] = tot_prod[tree[v][i]] * suf[i+1];
    }    

    Mint pref = 1;
    for (int i = 0; i < n; i++) {
        Mint level = (i + 1 < n ? suf[i+1]: 1);
        level *= pref;
        dfs2(tree[v][i], w * level); 
        pref *= tot_prod[tree[v][i]];
    }
}

void init(int N, int M, std::vector<int> P, std::vector<int> A) {
    ::N = N;
    ::M = M;
    ::P = P;
    ::A = A;

    tree.resize(N);

    for (int i = 1; i < N + M; i++) {
        tree[P[i]].push_back(i);
    }

    tot_prod.resize(N + M);
    dfs1(0);
    cerr << "dfs1 done\n";
    weight.resize(M);
    dfs2(0, 1);
    cerr << "dfs2 done\n";

    // for (int i = 0; i < N; i++) {
    //     cerr << "tot_prod[" << i << "] = " << tot_prod[i].get() << "\n";
    // }

    // for (int i = 0; i < M; i++) {
    //     cerr << "weight[" << i << "] = " << weight[i].get() << "\n";
    // }
}

int count_ways(int L, int R) {
    // naive
    // cerr << "count: " << L << " " << R << "\n";
    L -= N;
    R -= N;

    for (int i = L; i <= R; i++) {
        A[i] ^= 1;
    }

    Mint res = 0;
    for (int i = 0; i < M; i++)
        res += weight[i] * Mint(A[i]);

    return res.get();
}

Compilation message

circuit.cpp: In function 'void dfs2(int, Mint)':
circuit.cpp:71:5: error: 'assert' was not declared in this scope
   71 |     assert(n);
      |     ^~~~~~
circuit.cpp:6:1: note: 'assert' is defined in header '<cassert>'; did you forget to '#include <cassert>'?
    5 | #include <iostream>
  +++ |+#include <cassert>
    6 | #include <vector>