// 제출 #692514
//
// # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
// 692514 |  | CatalinT | 디지털 회로 (IOI22_circuit) | C++17 | 100 / 100 | 1938 ms | 23728 KiB
#include "circuit.h"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <utility>
#include <vector>

using namespace std;

/// Integer modulo the compile-time constant `mod`, stored as a residue in [0, mod).
///
/// Addition and subtraction form the raw 32-bit sum before reducing, so `mod`
/// must not exceed 2^31 or `n + m.n` / `n + mod - m.n` could wrap around.
/// Multiplication widens to 64 bits before reducing.
template <std::uint32_t mod>
class modint {
    static_assert(mod >= 1 && mod <= (std::uint32_t(1) << 31),
                  "modint requires 1 <= mod <= 2^31 so 32-bit sums cannot overflow");

private:
    std::uint32_t n;  // residue, always in [0, mod)

public:
    modint() : n(0) {}

    /// Reduces any signed 64-bit value (negative values and INT64_MIN included)
    /// into [0, mod). Intentionally implicit so plain integers mix with modint.
    modint(std::int64_t n_) : n(0) {
        // `%` truncates toward zero, so a negative remainder is shifted up by mod.
        // Unlike negating n_ first, this never overflows (e.g. for INT64_MIN).
        std::int64_t r = n_ % static_cast<std::int64_t>(mod);
        if (r < 0) r += mod;
        n = static_cast<std::uint32_t>(r);
    }

    static constexpr std::uint32_t get_mod() { return mod; }
    std::uint32_t get() const { return n; }

    bool operator==(const modint& m) const { return n == m.n; }
    bool operator!=(const modint& m) const { return n != m.n; }

    modint& operator+=(const modint& m) {
        n += m.n;
        if (n >= mod) n -= mod;
        return *this;
    }
    modint& operator-=(const modint& m) {
        n += mod - m.n;
        if (n >= mod) n -= mod;
        return *this;
    }
    modint& operator*=(const modint& m) {
        n = static_cast<std::uint32_t>(std::uint64_t(n) * m.n % mod);
        return *this;
    }

    modint operator+(const modint& m) const { return modint(*this) += m; }
    modint operator-(const modint& m) const { return modint(*this) -= m; }
    modint operator*(const modint& m) const { return modint(*this) *= m; }
    /// Division by an invertible value (asserted inside inv()).
    modint operator/(const modint& m) const { return modint(*this) * m.inv(); }

    /// Multiplicative inverse via the extended Euclidean algorithm.
    ///
    /// Valid for any modulus whenever gcd(n, mod) == 1, which is asserted.
    /// Fermat's little theorem (pow(mod - 2)) would only be correct for prime
    /// moduli, and 1'000'002'022 = 2 * 223 * 2242157 is not prime.
    modint inv() const {
        // Invariant: a ≡ x * n and b ≡ y * n (mod mod).
        std::int64_t a = n, b = mod;
        std::int64_t x = 1, y = 0;
        while (b != 0) {
            const std::int64_t q = a / b;
            std::int64_t t = a - q * b;
            a = b;
            b = t;
            t = x - q * y;
            x = y;
            y = t;
        }
        assert(a == 1 && "modint::inv: value is not invertible modulo mod");
        return modint(x);
    }

    /// Binary exponentiation: returns (*this)^b.
    modint pow(std::uint64_t b) const {
        modint ans = 1, m = modint(*this);
        while (b) {
            if (b & 1) ans *= m;
            m *= m;
            b >>= 1;
        }
        return ans;
    }
};

using Mint = modint<1'000'002'022>; 

int N, M;

vector<int> P;
vector<int> A;

vector<vector<int>> tree;

vector<Mint> weight;
vector<Mint> tot_prod;

/// Square-root decomposition over the leaf states A[0..M-1] with weights
/// weight[0..M-1]. query(L, R) flips A[L..R] and returns sum(weight[i] * A[i]).
///
/// Blocks lying fully inside [L, R] are flipped lazily in O(1) through `rev`.
/// The (at most two) partially covered blocks are first materialised with
/// build() and then updated leaf by leaf.
///
/// NOTE(review): assumes every A[i] is 0 or 1 — the lazy flip and the per-leaf
/// update rely on it, and nothing in this struct checks it.
struct SQRT {
    struct Block {
        Mint cur_sum;  // sum of weight[i] over "on" leaves in [l, r], counting a pending flip
        Mint tot_sum;  // sum of weight[i] over all leaves in [l, r]
        int rev;       // 1 while a flip of A[l..r] is pending (not yet written to A)

        int l, r;      // inclusive range of leaf indices covered by this block

        /// Writes any pending flip into A[l..r] and recomputes both sums from scratch.
        void build() {
            tot_sum = 0;
            cur_sum = 0;
            for (int i = l; i <= r; i++) {
                A[i] ^= rev;
                tot_sum += weight[i];
                if (A[i]) cur_sum += weight[i];
            }
            rev = 0;
        }

        /// Flips every leaf of the block lazily: the "on" weight becomes the "off" weight.
        void flip() {
            rev ^= 1;
            cur_sum = tot_sum - cur_sum;
        }

        /// Flips the single leaf i in place. Requires i in [l, r] and no pending flip.
        /// Adds or subtracts weight[i] directly rather than doing two modular
        /// multiplications by A[i]; this runs for up to 2 * BLOCK leaves per query.
        void flip_one(int i) {
            if (A[i]) {
                cur_sum -= weight[i];
            } else {
                cur_sum += weight[i];
            }
            A[i] ^= 1;
        }
    };

    int BLOCK = 0;   // leaves per block; the last block may be shorter
    int NBLOCK = 0;  // number of blocks, ceil(M / BLOCK)
    vector<Block> blocks;

    SQRT() {}

    /// Partitions leaves [0, M) into consecutive blocks of `block_size` leaves
    /// and builds each one. Reads the globals M, A and weight, so they must
    /// already be filled in.
    SQRT(int block_size) : BLOCK(block_size) {
        assert(BLOCK > 0);
        NBLOCK = (M + BLOCK - 1) / BLOCK;
        blocks.resize(NBLOCK);
        for (int i = 0; i < NBLOCK; i++) {
            const int l = i * BLOCK;
            const int r = min(M - 1, l + BLOCK - 1);
            blocks[i] = Block{0, 0, 0, l, r};
            blocks[i].build();
        }
    }

    /// Flips A[L..R] (0-based leaf indices, L <= R, both in [0, M)) and returns
    /// the resulting sum of weight[i] * A[i] over all leaves.
    Mint query(int L, int R) {
        const int s = L / BLOCK;
        const int e = R / BLOCK;

        // Partially covered end blocks: push their pending flip into A first.
        blocks[s].build();
        if (e != s) blocks[e].build();

        const int s_last = min(blocks[s].r, R);
        for (int i = L; i <= s_last; i++) blocks[s].flip_one(i);

        for (int j = s + 1; j < e; j++) blocks[j].flip();

        if (e > s) {
            for (int i = blocks[e].l; i <= R; i++) blocks[e].flip_one(i);
        }

        Mint ans = 0;
        for (const Block& b : blocks) ans += b.cur_sum;
        return ans;
    }
};

// Global structure; (re)assigned in init() once weights are known.
SQRT ds;

/// Fills tot_prod[u] for every vertex u in the subtree of v and returns tot_prod[v].
///
/// tot_prod[u] is 1 for a leaf (u >= N). For an internal vertex it is the child
/// count times the product of the children's tot_prod values.
///
/// Iterative on purpose: a recursive DFS uses one call frame per tree level, and
/// nothing here bounds the height (a path-shaped tree of N + M vertices would
/// recurse that deep and risk overflowing the call stack).
Mint dfs1(int v) {
    // Collect the subtree so that every vertex appears before all of its children.
    vector<int> order;
    vector<int> pending;
    pending.push_back(v);
    while (!pending.empty()) {
        const int u = pending.back();
        pending.pop_back();
        order.push_back(u);
        if (u < N) {
            for (int c : tree[u]) pending.push_back(c);
        }
    }

    // Walk that order backwards, so children are finished before their parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int u = *it;
        if (u >= N) {
            tot_prod[u] = 1;
            continue;
        }
        Mint prod = size(tree[u]);
        for (int c : tree[u]) prod *= tot_prod[c];
        tot_prod[u] = prod;
    }
    return tot_prod[v];
}

void dfs2(int v, Mint w) {
    if (v >= N) {
        weight[v - N] = w;
        return;
    }

    int n = size(tree[v]);

    assert(n);

    vector<Mint> suf(n);

    suf[n-1] = tot_prod[tree[v].back()];
    for (int i = n - 2; i >= 0; i--) {
        suf[i] = tot_prod[tree[v][i]] * suf[i+1];
    }    

    Mint pref = 1;
    for (int i = 0; i < n; i++) {
        Mint level = (i + 1 < n ? suf[i+1]: 1);
        level *= pref;
        dfs2(tree[v][i], w * level); 
        pref *= tot_prod[tree[v][i]];
    }
}

void init(int N, int M, std::vector<int> P, std::vector<int> A) {
    ::N = N;
    ::M = M;
    ::P = P;
    ::A = A;

    tree.resize(N);

    for (int i = 1; i < N + M; i++) {
        tree[P[i]].push_back(i);
    }

    tot_prod.resize(N + M);
    dfs1(0);
    weight.resize(M);
    dfs2(0, 1);

    ds = SQRT(400);
}

/// Toggles the leaves with vertex ids L..R (inclusive, expected within [N, N + M))
/// and returns the updated weighted sum of "on" leaves as a residue modulo
/// 1'000'002'022.
int count_ways(int L, int R) {
    // The decomposition is indexed by leaf position, not by vertex id.
    const int first_leaf = L - N;
    const int last_leaf = R - N;
    return ds.query(first_leaf, last_leaf).get();
}
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...
// # | Verdict | Execution time | Memory | Grader output
// Fetching results...