Submission #439541

# | Time | Username | Problem | Language | Result | Execution time | Memory
439541 | (not shown) | dacin21 | Keys (IOI21_keys) | C++17 | 67 / 100 | 2571 ms | 239748 KiB
#pragma GCC optimize("O3")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx") // codeforces
//#pragma GCC target("avx,avx2,fma")
//#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,tune=native") // yandex

#undef _GLIBCXX_DEBUG
#include <bits/stdc++.h>
using namespace std;

// Work counters bumped throughout the solver; they are only printed when
// LOCAL_RUN is defined.
uint64_t steps{0};
uint64_t steps_2{0};


/// Ordered map with integral keys that supports destructive merging.
///
/// Nodes form a binary trie over the whole key range
/// [numeric_limits<Key>::min(), numeric_limits<Key>::max()]. A node covering
/// [a, b] sends keys <= key_middle(a, b) to its left child and larger keys to
/// its right child. In addition, every node stores the smallest key of its
/// subtree, so front() is simply the root and lookups can stop early.
///
/// absorb() moves all entries of another map into this one. Values whose keys
/// occur in both maps are combined by a caller-supplied function. The map owns
/// its nodes, so it is movable but not copyable.
template<typename Key, typename Value>
class Mergeable_Map{
    struct Node{
        Node *l, *r;
        Key key;
        Value value;
    };
    static constexpr Key key_min(){
        return numeric_limits<Key>::min();
    }
    static constexpr Key key_max(){
        return numeric_limits<Key>::max();
    }
    // Split point of [l, r]. Halving each bound separately keeps the
    // arithmetic from overflowing where r - l itself would.
    static Key key_middle(Key const&l, Key const&r){
        return l + ((r>>1)-(l>>1));
    }

    // Deletes u and all of its descendants. The recursion depth is the trie
    // height, which is bounded by the number of bits in Key.
    static void destroy_rec(Node*u){
        if(!u) return;
        destroy_rec(u->l);
        destroy_rec(u->r);
        delete u;
    }

    // Pre-order visit of the subtree rooted at u, which covers keys [a, b].
    template<typename Fun>
    void foreach_rec(Node*u, Key a, Key b, Fun const&fun){
        fun(static_cast<Key const&>(u->key), u->value);
        const Key m = key_middle(a, b);
        if(u->l) foreach_rec(u->l, a, m, fun);
        if(u->r) foreach_rec(u->r, m+1, b, fun);
    }

    // Merges subtree v into subtree u (both cover [a, b]) and returns the
    // merged root. Nodes of v are either relinked or deleted. Equal keys are
    // combined with merge_fun. Node contents may have been swapped earlier in
    // the merge, so merge_fun's argument order is not guaranteed.
    template<typename Fun>
    Node* merge_rec(Node*u, Node*v, Key a, Key b, Fun const&merge_fun){
        if(!u) return v;
        if(!v) return u;
        const Key m = key_middle(a, b); // must be Key: an int would truncate wider keys
        if(v->l) u->l = merge_rec(u->l, v->l, a, m, merge_fun);
        if(v->r) u->r = merge_rec(u->r, v->r, m+1, b, merge_fun);
        v->l = v->r = nullptr;
        if(u->key == v->key){
            --root_size;
            merge_fun(u->value, v->value);
            delete v;
            return u;
        }
        // Keep the smaller key at u (heap order), then push the now childless
        // node v down into the matching half.
        if(u->key > v->key){
            swap(u->key, v->key);
            swap(u->value, v->value);
        }
        if(v->key <= m){
            u->l = merge_rec(u->l, v, a, m, merge_fun);
        } else {
            u->r = merge_rec(u->r, v, m+1, b, merge_fun);
        }
        return u;
    }

    // Looks up x. If it is absent and find_only is false, inserts (x, v).
    // Returns the node holding x, or nullptr if find_only and x is absent.
    template<bool find_only>
    Node* find_or_emplace(Key x, Value v){
        auto construct_new_node = [&]() -> Node*{
            if(find_only) return nullptr;
            ++root_size;
            return new Node{nullptr, nullptr, x, v};
        };
        if(!root){
            return root = construct_new_node();
        }
        Node* ret = nullptr;
        Node*u = root;
        Key a = key_min();
        Key b = key_max();
        for(;;){
            assert(u);
            if(x < u->key){
                // u->key is the subtree minimum, so x cannot be below u.
                if(find_only){
                    return nullptr;
                }
                // x takes u's place and the displaced entry continues downwards.
                swap(x, u->key);
                swap(v, u->value);
                if(!ret) ret = u; // the inserted key now lives in u
            } else if(x == u->key){
                return u;
            }
            const Key m = key_middle(a, b);
            if(x <= m){
                if(!u->l){
                    u->l = construct_new_node();
                    return ret ? ret : u->l;
                }
                b = m;
                u = u->l;
            } else {
                if(!u->r){
                    u->r = construct_new_node();
                    return ret ? ret : u->r;
                }
                a = m+1;
                u = u->r;
            }
        }
    }

    // Default combiner for absorb(): adds the second value onto the first.
    struct merge_with_plus{
        template<typename T>
        void operator()(T &a, T const&b) const{
            a += b;
        }
    };

    // Removes the largest key. The largest key lives in the right subtree if
    // there is one, otherwise in the left one. Requires a non-empty map.
    void erase_back_rec(){
        assert(root);
        --root_size;
        Node** uu = &root;
        Node*u = root;
        for(;;){
            if(u->r){
                uu = &(u->r);
                u = u->r;
            } else if(u->l){
                uu = &(u->l);
                u = u->l;
            } else {
                delete u;
                *uu = nullptr;
                break;
            }
        }
    }

    // Removes the smallest key (the root). Each step pulls up the minimum of
    // the left child, or of the right child if there is no left child, until
    // a leaf can be deleted. Requires a non-empty map.
    void erase_front_rec(){
        assert(root);
        --root_size;
        Node** uu = &root;
        Node*u = root;
        for(;;){
            if(u->l){
                swap(u->key, u->l->key);
                swap(u->value, u->l->value);
                uu = &(u->l);
                u = u->l;
            } else if(u->r){
                swap(u->key, u->r->key);
                swap(u->value, u->r->value);
                uu = &(u->r);
                u = u->r;
            } else {
                delete u;
                *uu = nullptr;
                break;
            }
        }
    }

    // Removes key x if it is present and does nothing otherwise.
    void erase_rec(Key const&x){
        Node** uu = &root;
        Node*u = root;
        Key a = key_min();
        Key b = key_max();
        while(u && u->key != x){
            // Every node holds its subtree minimum, so x < u->key means absent.
            if(x < u->key) return;
            auto const m = key_middle(a, b);
            if(x <= m){
                uu = &(u->l);
                u = u->l;
                b = m;
            } else {
                uu = &(u->r);
                u = u->r;
                a = m+1;
            }
        }
        if(!u) return;
        --root_size;
        for(;;){
            if(u->l){
                swap(u->key, u->l->key);
                swap(u->value, u->l->value);
                uu = &(u->l);
                u = u->l;
            } else if(u->r) {
                swap(u->key, u->r->key);
                swap(u->value, u->r->value);
                uu = &(u->r);
                u = u->r;
            } else {
                delete u;
                *uu = nullptr;
                break;
            }
        }
    }

public:
    Mergeable_Map() = default;
    // Nodes are owned exclusively, so copying is disabled and moving transfers them.
    Mergeable_Map(Mergeable_Map const&) = delete;
    Mergeable_Map& operator=(Mergeable_Map const&) = delete;
    Mergeable_Map(Mergeable_Map&& o) noexcept
        : root_size(exchange(o.root_size, size_t{0})), root(exchange(o.root, nullptr)){}
    Mergeable_Map& operator=(Mergeable_Map&& o) noexcept{
        if(this != &o){
            destroy_rec(root);
            root_size = exchange(o.root_size, size_t{0});
            root = exchange(o.root, nullptr);
        }
        return *this;
    }
    ~Mergeable_Map(){
        destroy_rec(root);
    }

    bool empty() const{
        return !root;
    }
    size_t size() const{
        return root_size;
    }

    /// Largest key and its value. Requires a non-empty map.
    pair<Key&, Value&> back(){
        assert(!empty());
        Node*u = root;
        for(;;){
            if(u->r) u = u->r;
            else if(u->l) u = u->l;
            else return {u->key, u->value};
        }
    }
    void pop_back(){
        erase_back_rec();
    }
    /// Smallest key and its value (stored at the root). Requires a non-empty map.
    pair<Key&, Value&> front(){
        assert(!empty());
        return {root->key, root->value};
    }
    void pop_front(){
        erase_front_rec();
    }

    /// 1 if x is present, else 0.
    int count(Key const&x){
        Node*u = find_or_emplace<true>(x, Value{});
        return !!u;
    }
    /// Value for x; a value-initialised entry is inserted if x is absent.
    Value& operator[](Key const&x){
        Node*u = find_or_emplace<false>(x, Value{});
        return u->value;
    }
    /// Removes x if present; a missing key is ignored.
    void erase(Key const&x){
        erase_rec(x);
    }

    /// Moves every entry of o into this map and leaves o empty. Values of keys
    /// present in both maps are combined with merge_fun(Value&, Value&), whose
    /// argument order is unspecified. By default they are added together.
    template<typename Fun = merge_with_plus>
    void absorb(Mergeable_Map &o, Fun merge_fun = Fun{}){
        if(&o == this) return; // merging a trie with itself would delete live nodes
        root_size += o.root_size;
        root = merge_rec(root, o.root, key_min(), key_max(), merge_fun);
        o.root_size = 0;
        o.root = nullptr;
    }
    /// Calls fun(Key const&, Value&) once per entry, in trie pre-order (not sorted).
    template<typename Fun>
    void foreach(Fun fun){
        if(root){
            foreach_rec(root, key_min(), key_max(), fun);
        }
    }
private:
    size_t root_size = 0;
    Node *root = nullptr;
};

// A connected group of vertices linked by edges that all share one colour.
struct Edge_Compoment{
    int key;                    // the colour shared by the component's edges
    std::vector<int> vertices;  // member vertices, in the order they were discovered
};

// Global registry of edge components; an edge component's id is its index here.
std::vector<Edge_Compoment> edge_components;

/// Moves every element of b to the end of a and leaves b empty.
///
/// If a is the shorter vector, the two first swap storage. That way only
/// min(|a|, |b|) elements are transferred (small-to-large merging). As a
/// result, the order of the combined elements is not preserved. Joining a
/// vector with itself is a no-op.
template<typename T>
void join(vector<T> &a, vector<T> &b){
    if(&a == &b) return; // inserting a vector's own range into itself is undefined
    if(a.size() < b.size()){
        a.swap(b);
    }
    steps_2 += b.size();
    // b is cleared immediately afterwards, so move its elements instead of copying.
    a.insert(a.end(), make_move_iterator(b.begin()), make_move_iterator(b.end()));
    b.clear();
}


/// State of one merged group of vertices.
///
/// Edge components are referred to by their index into edge_components. The
/// ones this group touches are bucketed by colour: a colour whose key the
/// group holds lives in reachable_edges, any other colour in unreachable_edges.
struct Component{
    vector<int> vertices;                                // vertices belonging to this group
    unordered_set<int> keys;                             // key colours the group holds
    Mergeable_Map<int, int> edge_pos;                    // edge-component id -> next vertex index to offer
    Mergeable_Map<int, vector<int> > reachable_edges;    // held colour -> edge-component ids left to scan
    Mergeable_Map<int, vector<int> > unreachable_edges;  // colour not yet held -> edge-component ids waiting for it

    /// Offers the not-yet-consumed vertices of reachable edge components to
    /// callback(int const&). It walks the smallest colour first and, within a
    /// colour, the most recently appended id first.
    ///
    /// Returns true as soon as callback returns true. The position is NOT
    /// advanced in that case, so the same vertex is offered again on the next
    /// call. Fully scanned ids and colours are dropped. Returns false once
    /// nothing reachable remains.
    template<typename T>
    bool search(T callback){
        while(!reachable_edges.empty()){
            auto it = reachable_edges.front();
            while(!it.second.empty()){
                auto&p = it.second.back();
                auto &pos = edge_pos[p];
                auto&vv = edge_components[p].vertices;
                while(pos < (int)vv.size()){
                    // callback may restructure this component; bail out right away.
                    if(callback(vv[pos])) return true;
                    ++pos;
                }
                it.second.pop_back();
            }
            reachable_edges.pop_front();
        }
        return false;
    }

    /// Adds key colour c. Any edge components waiting for c become reachable.
    void add_key(int c){
        if(!keys.count(c)){
            keys.insert(c);
            if(unreachable_edges.count(c)){
                auto &v = reachable_edges[c];
                auto &w = unreachable_edges[c];
                join(v, w);
                unreachable_edges.erase(c);
            }
        }
    }

    /// Weight used to decide which component absorbs the other when merging.
    size_t size(){
        return reachable_edges.size() + unreachable_edges.size() + keys.size();
    }

    /// Merges o into this component. Afterwards o holds no vertices, keys or
    /// edge lists.
    void absorb(Component &o){
        join(vertices, o.vertices);
        reachable_edges.absorb(o.reachable_edges, join<int>);
        // Lists in o that wait for a colour we already hold become reachable now.
        vector<int> unlocked;
        o.unreachable_edges.foreach([&](auto const&first, auto &second){
            ++steps;
            if(keys.count(first)){
                join(reachable_edges[first], second);
                unlocked.push_back(first);
            }
        });
        // Drop those now-empty entries so they are not merged into
        // unreachable_edges, where they would linger for colours we already hold.
        for(int k : unlocked){
            o.unreachable_edges.erase(k);
        }
        unreachable_edges.absorb(o.unreachable_edges, join<int>);
        // New keys from o may unlock lists already waiting in this component.
        for(auto const&e : o.keys){
            ++steps;
            add_key(e);
        }
        // o's keys are now part of this component; release o's copy right away.
        unordered_set<int>().swap(o.keys);
        edge_pos.absorb(o.edge_pos, [&](auto &a, auto const&b){ a = max(a, b); });
    }
};

/// Union-find over vertex ids [0, n). The root of each set owns the
/// Component describing the whole set (comps[root]).
struct DSU{
    DSU(int n_) : n(n_), p(n, -1), comps(n){

    }

    /// Returns the root of x's set and points every node on the path directly
    /// at it. This is iterative, so a long parent chain cannot exhaust the call
    /// stack. steps_2 still grows by one per node on the path, as before.
    int f(int x){
        int root = x;
        ++steps_2;
        while(p[root] >= 0){
            root = p[root];
            ++steps_2;
        }
        while(x != root){
            const int next = p[x];
            p[x] = root;
            x = next;
        }
        return root;
    }

    /// Component of the set containing x.
    Component& c(int x){
        return comps[f(x)];
    }

    /// Merges the sets of a and b. The Component with the smaller size()
    /// weight is absorbed into the other. Returns false if a and b were
    /// already in the same set.
    bool u(int a, int b){
        ++steps;
        a = f(a);
        b = f(b);
        if(a == b) return false;
        if(comps[a].size() < comps[b].size()){
            swap(a, b);
        }
        comps[a].absorb(comps[b]);
        p[b] = a;
        return true;
    }

    int n;
    vector<int> p;            // parent of x, or -1 if x is a root
    vector<Component> comps;  // only entries at roots are meaningful

};

/// Returns ret with ret[i] == 1 iff vertex i belongs to one of the components
/// without any outgoing traversable edge that have the fewest vertices.
///
/// Outline of the implementation:
///  1. For every colour, split the edges of that colour into connected "edge
///     components". Each touched vertex records the component id under that
///     colour as not yet reachable.
///  2. Every vertex starts as its own component holding key r[i].
///  3. Start a walk from each unvisited component, following traversable
///     edges. A cycle on the current walk is merged into one component. A walk
///     that hits a component finished earlier is abandoned. A component with
///     no way out is a candidate answer.
///
/// @param r  r[i] is the key colour held by vertex i.
/// @param u  edge j connects u[j] ...
/// @param v  ... and v[j].
/// @param c  c[j] is the key colour needed to use edge j.
vector<int> find_reachable(vector<int> r, vector<int> u, vector<int> v, vector<int> c) {
    #ifdef LOCAL_RUN
    auto time_a = chrono::high_resolution_clock::now();
    #endif
    const int m = u.size();
    const int n = r.size();
    // Edge indices grouped by the colour they require.
    unordered_map<int, vector<int> > c_inv;
    c_inv.reserve(1<<10);
    c_inv.max_load_factor(0.25);
    for(int i=0; i<m; ++i){
        c_inv[c[i]].push_back(i);
    }

    DSU uni(n);
    // find edge components
    edge_components.clear();
    vector<vector<int> > g(n);
    vector<int> vis(n, -1);
    int tim = -1; // = index of last edge_component

    vector<int> vert;
    vert.reserve(n);
    vector<int> dfs_stack;
    for(auto const&e:c_inv){
        // Temporarily add this colour's edges to the adjacency lists.
        for(auto const&f:e.second){
            g[u[f]].push_back(v[f]);
            g[v[f]].push_back(u[f]);
        }
        const int t0 = tim;
        // Iterative DFS that labels the component of `start` with id tim.
        // Neighbours are pushed in reverse so they are popped in adjacency
        // order. This reproduces the visiting order (and `steps` count) of a
        // recursive DFS without recursing once per vertex on long
        // single-colour paths.
        auto explore = [&](int start){
            dfs_stack.push_back(start);
            while(!dfs_stack.empty()){
                const int x = dfs_stack.back();
                dfs_stack.pop_back();
                ++steps;
                if(vis[x] == tim) continue;
                vis[x] = tim;
                uni.comps[x].unreachable_edges[e.first].push_back(tim);
                vert.push_back(x);
                for(auto it = g[x].rbegin(); it != g[x].rend(); ++it){
                    dfs_stack.push_back(*it);
                }
            }
        };
        for(auto const&f:e.second){
            ++steps_2;
            // vis <= t0: not yet labelled by any component of this colour.
            if(vis[u[f]] <= t0){
                ++tim;
                vert.clear();
                edge_components.push_back(Edge_Compoment{e.first});
                explore(u[f]);
                edge_components.back().vertices = vert;
            }
        }
        // Remove this colour's edges again; they are the last entries of each list.
        for(auto const&f:e.second){
            g[u[f]].pop_back();
            g[v[f]].pop_back();
        }
    }
    #ifdef LOCAL_RUN
    auto time_b = chrono::high_resolution_clock::now();
    cerr << "tmp: " << chrono::duration_cast<chrono::nanoseconds>(time_b - time_a).count()*1e-9 << "\n";
    #endif

    for(int i=0; i<n; ++i){
        uni.c(i).add_key(r[i]);
        uni.c(i).vertices.push_back(i);
    }
    #ifdef LOCAL_RUN
    auto time_c = chrono::high_resolution_clock::now();
    cerr << "go search: " << chrono::duration_cast<chrono::nanoseconds>(time_c - time_b).count()*1e-9 << "\n";;
    #endif
    // out.first: smallest candidate size so far; out.second: all vertices of that size.
    pair<int, vector<int> > out(n+1, {});
    // run search
    // vis[root]: 0 = unvisited, tim = on the current walk, other = finished earlier.
    vis.assign(n, 0);
    tim = 1;
    vector<int> to(n, -1);   // next component on the current walk
    vector<int> prev(n, -1); // previous component on the current walk
    for(int i=0; i<n; ++i){
        int a = uni.f(i);
        ++tim;
        if(vis[a] != 0) continue;
        vis[a] = tim;
        prev[a] = -1;
        for(;;){
            auto res = uni.comps[a].search([&](int const&x){
                const int b = uni.f(x);
                if(b == a) return false; // edge stays inside the current component
                if(vis[b] == 0){
                    // Unexplored component: extend the walk to b.
                    to[a] = b;
                    prev[b] = a;
                    vis[b] = tim;
                    a = b;
                    return true;
                }
                if(vis[b] == tim){
                    // b is on the current walk: merge the cycle b -> ... -> a.
                    for(int w = b; w != a; w = to[w]){
                        uni.u(a, w);
                    }
                    const int aa = uni.f(a);
                    if(prev[b] != -1){
                        assert(to[prev[b]] == b);
                        to[prev[b]] = aa;
                    }
                    prev[aa] = prev[b];
                    a = aa;
                    to[a] = -1;
                    assert(vis[a] == tim);
                    return true;
                }
                // b was finished by an earlier walk, so nothing on this walk is minimal.
                a = -1;
                return true;
            });
            if(a == -1) break; // found old vertices
            if(!res){
                // No way out of a: it is a candidate answer.
                auto const&vs = uni.comps[a].vertices;
                if((int)vs.size() < out.first){
                    out.first = vs.size();
                    out.second.clear();
                }
                if((int)vs.size() == out.first){
                    out.second.insert(out.second.end(), vs.begin(), vs.end());
                }
                break;
            }
        }
    }
    vector<int> ret(n);
    for(auto const&e:out.second){
        ret[e] = 1;
    }
    #ifdef LOCAL_RUN
    auto time_d = chrono::high_resolution_clock::now();
    cerr << "done: " << chrono::duration_cast<chrono::nanoseconds>(time_d - time_c).count()*1e-9 << "\n";;
    cerr << steps << "\n";
    cerr << steps_2 << "\n";
    #endif
    return ret;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...