Submission #628028

# | Time | Username | Problem | Language | Result | Execution time | Memory
628028 | | Hanksburger | Digital Circuit (IOI22_circuit) | C++17 | 100 / 100 | 1275 ms | 35568 KiB
#include "circuit.h"
#include <bits/stdc++.h>
using namespace std;
// Modulus applied to every product and sum in this file (1e9+2022 = 1000002022).
const long long mod=1e9+2022;
// Segment tree over the index range [n, n+m-1] (see build/update):
//   sum[i]  - sum of a[] over the leaves below node i (mod `mod`), fixed after build
//   seg[i]  - sum of a[l]*state[l] over those leaves, with flips applied by push()
//   lazy[i] - 1 when a pending "flip" still has to be applied to node i
// Per-vertex data, indexed by vertex id (vertices >= n are handled as leaves by dfs/dfs2):
//   state[v] - initial 0/1 value copied from A in init()
//   dp[v]    - 1 for leaves; for internal vertices, (#children) * product of children's dp (see dfs)
//   a[v]     - for leaves, the product of dp[] over all siblings of every vertex on the
//              root-to-leaf path (see dfs2)
// n, m are copies of the N and M arguments of init().
long long seg[800005], sum[800005], lazy[800005], state[200005], dp[200005], a[200005], n, m;
// adj[p] lists the children of vertex p, filled from the parent array P in init().
vector<long long> adj[200005];
// Computes dp[] for the whole subtree rooted at `u`.
// A vertex with index >= n is a leaf and gets dp = 1. Any other vertex gets
// (number of children) * (product of the children's dp), reduced modulo `mod`.
void dfs(long long u)
{
    if (u >= n)
    {
        dp[u] = 1;
        return;
    }
    // Start from the child count, then fold in every child's value.
    long long ways = adj[u].size();
    for (const long long child : adj[u])
    {
        dfs(child);
        ways = ways * dp[child] % mod;
    }
    dp[u] = ways;
}
void dfs2(long long u, long long x)
{
    if (u>=n)
    {
        a[u]=x;
        return;
    }
    vector<long long> vec;
    long long sq=sqrt(adj[u].size());
    for (long long i=0; i<adj[u].size(); i++)
    {
        long long v=adj[u][i];
        if (i%sq==0)
            vec.push_back(dp[v]);
        else
            vec[vec.size()-1]=(vec[vec.size()-1]*dp[v])%mod;
    }
//    cout << "vec: ";
//    for (long long i=0; i<vec.size(); i++)
//        cout << vec[i] << ' ';
//    cout << '\n';
    for (long long i=0; i<adj[u].size(); i++)
    {
        long long v=adj[u][i], y=x;
        for (long long j=0; j<i/sq; j++)
            y=(y*vec[j])%mod;
        for (long long j=i/sq+1; j<vec.size(); j++)
            y=(y*vec[j])%mod;
        for (long long j=i/sq*sq; j<i; j++)
            y=(y*dp[adj[u][j]])%mod;
        for (long long j=i+1; j<min((long long)adj[u].size(), (i/sq+1)*sq); j++)
            y=(y*dp[adj[u][j]])%mod;
        dfs2(v, y);
    }
}
// Applies a pending flip stored on node `i`, if there is one.
// Flipping replaces seg[i] by sum[i] - seg[i] (mod `mod`) and toggles the
// pending flag of both child slots 2*i and 2*i+1. `l` and `r` are not used.
void push(long long i, long long l, long long r)
{
    if (!lazy[i])
        return;
    lazy[i] = 0;
    lazy[2 * i] ^= 1;
    lazy[2 * i + 1] ^= 1;
    seg[i] = (sum[i] + mod - seg[i]) % mod;
}
// Builds segment-tree node `i`, which covers the index range [l, r].
// A leaf stores a[l] in sum[] and a[l]*state[l] in seg[]; an inner node
// stores the sums of its two children modulo `mod`.
void build(long long i, long long l, long long r)
{
    if (l != r)
    {
        const long long mid = (l + r) / 2;
        const long long lc = 2 * i;
        const long long rc = lc + 1;
        build(lc, l, mid);
        build(rc, mid + 1, r);
        sum[i] = (sum[lc] + sum[rc]) % mod;
        seg[i] = (seg[lc] + seg[rc]) % mod;
        return;
    }
    sum[i] = a[l];
    seg[i] = state[l] * a[l];
}
// Flips every leaf in [ql, qr] inside node `i`, which covers [l, r].
// A fully covered node is flagged and resolved immediately via push(); otherwise
// both children are brought up to date, the overlapping ones are visited, and
// seg[i] is recomputed from the children.
void update(long long i, long long l, long long r, long long ql, long long qr)
{
    push(i, l, r);
    if (l >= ql && qr >= r)
    {
        lazy[i] = 1;
        push(i, l, r);
        return;
    }
    const long long mid = (l + r) / 2;
    const long long lc = 2 * i;
    const long long rc = lc + 1;
    // Both children must be clean before seg[i] is rebuilt from them below.
    push(lc, l, mid);
    push(rc, mid + 1, r);
    if (ql <= mid && qr >= l)
        update(lc, l, mid, ql, qr);
    if (qr > mid && ql <= r)
        update(rc, mid + 1, r, ql, qr);
    seg[i] = (seg[lc] + seg[rc]) % mod;
}
// One-time setup: stores N and M, builds the child lists from the parent
// array P, runs both tree passes (dfs, dfs2), copies the initial states from A
// to indices n..n+m-1 and builds the segment tree over that index range.
void init(int N, int M, vector<int> P, vector<int> A)
{
    n = N;
    m = M;
    const long long total = n + m;
    // Vertex 0 is the root, so its entry in P is skipped.
    for (long long child = 1; child < total; ++child)
        adj[P[child]].push_back(child);
    dfs(0);
    dfs2(0, 1);
    for (long long k = 0; k < m; ++k)
        state[n + k] = A[k];
    build(1, n, total - 1);
}
// Flips the states at indices L..R and returns the root value of the segment
// tree after the flip.
int count_ways(int L, int R)
{
    const long long first = n;
    const long long last = n + m - 1;
    update(1, first, last, L, R);
    return seg[1];
}

Compilation message (stderr)

circuit.cpp: In function 'void dfs2(long long int, long long int)':
circuit.cpp:30:26: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   30 |     for (long long i=0; i<adj[u].size(); i++)
      |                         ~^~~~~~~~~~~~~~
circuit.cpp:42:26: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   42 |     for (long long i=0; i<adj[u].size(); i++)
      |                         ~^~~~~~~~~~~~~~
circuit.cpp:47:35: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<long long int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
   47 |         for (long long j=i/sq+1; j<vec.size(); j++)
      |                                  ~^~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...