/*
 * Submission #1054478
 * Username: shmax
 * Problem: Digital Circuit (IOI22_circuit)
 * Language: C++17
 * Result: 100 / 100
 * Execution time: 684 ms
 * Memory: 43088 KiB
 * (The "Time" column of the original results table was not captured.)
 */
#include "circuit.h"
#include <bits/stdc++.h>

// Competitive-programming shorthands and shared state used throughout this file.
using namespace std;
// i32 is bound to the real 32-bit `int` BEFORE the macro below redefines `int`;
// it is used for the grader-facing signatures (init / count_ways).
using i32 = int;
// From here on every `int` token in this file means `long long`, which keeps
// products such as a * b in mult() (both factors below mod) from overflowing.
#define int long long
// Container size as a signed value (long long, because of the macro above).
#define len(x) ((int)(x.size()))
template<typename T>
using vec = vector<T>;

// Vector of vectors; used as an adjacency (children) list.
template<typename T>
using graph = vec<vec<T>>;

// Initial on/off states of the m leaf vertices, copied from A in init().
vec<int> cur;
// g[v] lists the children of vertex v (built from the parent array P in init()).
graph<int> g;
// Vertices 0..n-1 are internal gates; vertices n..n+m-1 are the m leaves.
int n, m;


// Modulus for every count in this file. A typed constant instead of a macro so it
// respects scope and type checking. Note: 1000002022 is even (not prime), so
// modular inverses do not generally exist; it fits in a 32-bit signed int.
constexpr int mod = 1000002022;

// Adds two residues in [0, mod) and folds the result back into [0, mod)
// with one conditional subtraction (no division).
int sum(int a, int b) {
    int total = a + b;
    if (total >= mod) {
        total -= mod;
    }
    return total;
}

// Subtracts residue b from residue a modulo mod: a negative difference is
// shifted up by one modulus, so inputs in [0, mod) yield a result in [0, mod).
int sub(int a, int b) {
    const int diff = a - b;
    return diff < 0 ? diff + mod : diff;
}

// Multiplies two residues modulo mod. Relies on `int` meaning `long long` in
// this file: two factors below mod (~1e9) give a product below 2^63.
int mult(int a, int b) {
    const int product = a * b;
    return product % mod;
}

// Computes a^b modulo mod by square-and-multiply, consuming one bit of the
// exponent per iteration. The exponent b must be non-negative.
int bin_pow(int a, int b) {
    int result = 1;
    for (int base = a, e = b; e != 0; e >>= 1) {
        if (e & 1) {
            result = mult(result, base);
        }
        base = mult(base, base);
    }
    return result;
}


// Capacity bound for the number of leaves stored in the segment tree
// (assumes m <= maxN).
constexpr int maxN = 100005;

// One segment-tree node:
//   length  - sum (mod `mod`) of the leaf weights in the node's range;
//   cur_sum - sum (mod `mod`) of the weights of the leaves currently switched on.
struct vertex {
    int length, cur_sum;
};

// Nodes are 1-indexed (children of v are 2v and 2v+1).
vertex stree[maxN * 4];
// pushes[v] == true: a flip was applied to node v but not yet to its children.
bool pushes[maxN * 4];

// Resets every node to {0, 0} and clears every pending flip.
void init() {
    std::fill(std::begin(stree), std::end(stree), vertex{0, 0});
    std::fill(std::begin(pushes), std::end(pushes), false);
}

// Hands the pending flip stored at node v (range [tl, tr]) down to its two
// children: each child's cur_sum becomes length - cur_sum and the child's own
// flag is toggled. A leaf (tl == tr) has no children, so only its flag is cleared.
void push(int v, int tl, int tr) {
    if (pushes[v]) {
        if (tl != tr) {
            for (int child : {v * 2, v * 2 + 1}) {
                pushes[child] = !pushes[child];
                stree[child].cur_sum = sub(stree[child].length, stree[child].cur_sum);
            }
        }
        pushes[v] = false;
    }
}

void build(int v, int tl, int tr, vec<int> &length) {
    if (tl == tr) {
        stree[v].length = length[tl];
        stree[v].cur_sum = 0;
        return;
    }
    int tm = (tl + tr) / 2;
    build(v * 2, tl, tm, length);
    build(v * 2 + 1, tm + 1, tr, length);
    stree[v].length = sum(stree[v * 2].length, stree[v * 2 + 1].length);
    stree[v].cur_sum = 0;
}


// Flips every leaf in [l, r] (on <-> off) inside the subtree of node v, which
// covers [tl, tr]; requires tl <= l <= r <= tr. Flipping a range turns its
// switched-on weight into length - cur_sum. Pending flips are pushed on entry,
// so the children are current before recursing and before v is recombined.
void update(int v, int tl, int tr, int l, int r) {
    push(v, tl, tr);
    if (tl == l && tr == r) {
        stree[v].cur_sum = sub(stree[v].length, stree[v].cur_sum);
        // Record the flip for the children and immediately hand it down one level.
        pushes[v] ^= 1;
        push(v, tl, tr);
        return;
    }
    int tm = (tl + tr) / 2;
    if (r <= tm) {
        update(v * 2, tl, tm, l, r);
    } else if (l > tm) {
        update(v * 2 + 1, tm + 1, tr, l, r);
    } else {
        update(v * 2, tl, tm, l, tm);
        update(v * 2 + 1, tm + 1, tr, tm + 1, r);
    }
    // Only cur_sum can change here; length is fixed once build() has run, so it
    // is deliberately not recomputed on every update.
    stree[v].cur_sum = sum(stree[v * 2].cur_sum, stree[v * 2 + 1].cur_sum);
}


// Returns the switched-on weight (mod `mod`) of leaves [l, r] within node v's
// range [tl, tr]; requires tl <= l <= r <= tr. Pending flips are pushed on the
// way down.
int get(int v, int tl, int tr, int l, int r) {
    push(v, tl, tr);
    if (l == tl && r == tr) {
        return stree[v].cur_sum;
    }
    const int tm = (tl + tr) / 2;
    const int lc = v * 2;
    const int rc = v * 2 + 1;
    if (r <= tm) {
        return get(lc, tl, tm, l, r);
    }
    if (l > tm) {
        return get(rc, tm + 1, tr, l, r);
    }
    return sum(get(lc, tl, tm, l, tm), get(rc, tm + 1, tr, tm + 1, r));
}

// Grader entry point. Builds the tree from the parent array P (P[i] is the parent
// of vertex i for i >= 1; vertex 0 is the root), computes the weight of every
// leaf j (vertex n + j), and loads the segment tree with those weights and the
// initial leaf states A (a leaf with A[j] == 1 starts switched on).
//
// cont[v] (v < n): deg(v) * product of cont[u] over children u < n; 1 for leaves.
// weight[v]: product over every ancestor a of v of cont[] of a's other children.
//   Equivalently, a leaf's weight is the product of deg(j) over every gate j
//   that is not on the root-to-leaf path.
//
// All passes are iterative over a BFS order, so no recursion depth proportional
// to n + m is needed (a path-shaped tree would otherwise recurse n + m deep).
void init(i32 N, i32 M, std::vector<i32> P, std::vector<i32> A) {
    n = N;
    m = M;

    // assign() rather than resize()/push_back() so state from a previous call
    // is discarded instead of accumulated.
    g.assign(n + m, vec<int>());
    for (int i = 1; i < len(P); i++) {
        g[P[i]].push_back(i);
    }
    cur.assign(A.begin(), A.begin() + m);

    // Top-down order from the root: every vertex appears after its parent.
    vec<int> order;
    order.reserve(n + m);
    order.push_back(0);
    for (int head = 0; head < len(order); head++) {
        for (int u : g[order[head]]) {
            order.push_back(u);
        }
    }

    // Bottom-up pass (reverse BFS order): children are finished before parents.
    vec<int> cont(n + m, 1);
    for (int k = len(order) - 1; k >= 0; k--) {
        const int v = order[k];
        if (v >= n) continue;
        cont[v] = len(g[v]);
        for (int u : g[v]) {
            if (u < n) {
                cont[v] = mult(cont[v], cont[u]);
            }
        }
    }

    // Top-down pass: a child's weight is its parent's weight times the product
    // of cont[] over its siblings (prefix * suffix products exclude the child).
    vec<int> weight(n + m, 0);
    weight[0] = 1;
    vec<int> pref, suff;
    for (int v : order) {
        if (v >= n) continue;
        const int deg = len(g[v]);
        pref.assign(deg + 1, 1);
        suff.assign(deg + 1, 1);
        for (int i = 0; i < deg; i++) {
            pref[i + 1] = mult(pref[i], cont[g[v][i]]);
        }
        for (int i = deg - 1; i >= 0; i--) {
            suff[i] = mult(suff[i + 1], cont[g[v][i]]);
        }
        for (int i = 0; i < deg; i++) {
            weight[g[v][i]] = mult(weight[v], mult(pref[i], suff[i + 1]));
        }
    }

    vec<int> dists(m);
    for (int j = 0; j < m; j++) {
        dists[j] = weight[n + j];
    }

    init();
    build(1, 0, m - 1, dists);
    for (int i = 0; i < m; i++) {
        if (cur[i] == 1) {
            update(1, 0, m - 1, i, i);
        }
    }
}

// Grader entry point: toggles the leaves with vertex ids L..R (leaf indices
// L - n .. R - n) and returns the total weight of the switched-on leaves modulo
// mod. The value is below mod, so it fits in a 32-bit int.
i32 count_ways(i32 L, i32 R) {
    const int first = L - n;
    const int last = R - n;
    update(1, 0, m - 1, first, last);
    return static_cast<i32>(get(1, 0, m - 1, 0, m - 1));
}
//422283126
/*
 * Grader result tables (not captured). The page showed the placeholder below
 * for each of its 8 result tables:
 *   # | Verdict | Execution time | Memory | Grader output
 *   Fetching results...
 */