답안 #1078612

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
1078612 2024-08-27T22:52:38 Z ProtonDecay314 늑대인간 (IOI18_werewolf) C++17
0 / 100
4000 ms 261064 KB
#include "werewolf.h"
#include<bits/stdc++.h>
using namespace std;

// 64-bit integers and nested vectors of them.
using ll = long long;
using vll = vector<ll>;
using vvll = vector<vll>;

// Plain ints and nested vectors of them.
using vi = vector<int>;
using vvi = vector<vi>;

// Pairs and vectors of pairs.
using pi = pair<int, int>;
using pll = pair<ll, ll>;
using vpi = vector<pi>;
using vpll = vector<pll>;
using vvpi = vector<vpi>;
using vvpll = vector<vpll>;

// Boolean vectors.
using vb = vector<bool>;
using vvb = vector<vb>;

// Short integers and vectors of them.
using si = short int;
using vsi = vector<si>;
using vvsi = vector<vsi>;

// Untie C++ streams from C stdio for faster I/O.
#define IOS ios_base::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr);
// Counting loops: L / LI go up over [mn, mx), LR / LIR go down over (mn, mx].
#define L(varll, mn, mx) for(ll varll = (mn); varll < (mx); varll++)
#define LR(varll, mx, mn) for(ll varll = (mx); varll > (mn); varll--)
#define LI(vari, mn, mx) for(int vari = (mn); vari < (mx); vari++)
#define LIR(vari, mx, mn) for(int vari = (mx); vari > (mn); vari--)
// Read every element of a container from cin.
#define INPV(varvec) for(auto& varveci : (varvec)) cin >> varveci
// Short member / method names.
#define fi first
#define se second
#define pb push_back
// Numeric extremes of a type.
#define INF(type) numeric_limits<type>::max()
#define NINF(type) numeric_limits<type>::min()
// Run the following statement once per test case read from cin.
#define TCASES int t; cin >> t; while(t--)

/*
I need to create
1. A regular union find with an extra augmentation:
each representative (root) points to the union find tree node that currently stands for its component
2. A union find tree, with each node initialized with null values for smallest, largest, and time unified. It must also have 18 binary jump pointers per node
3. A merge sort tree, with each interval storing points within it

It may also help to store the [vL, vR] values per node; this could be done with a vpi (vector of int pairs).

Overall complexity: O(m log m + n log n) preprocessing plus O(log^2 n) per query, i.e. within O((n + m + q) log^2 n).
*/
// Number of binary-lifting levels per node. Greedy lifting over levels
// 0..MAX_JUMP_PTR-1 can climb up to 2^18 - 1 edges, so trees must be
// shallower than that.
const int MAX_JUMP_PTR = 18;

/*
 * Node of a union find (reconstruction) tree.
 * Leaves (lt == nullptr) stand for original vertices; an internal node is
 * created for every merge of two components and remembers the merge time t.
 * After postorder_traverse(), [l, r] is the contiguous range of new indices of
 * the leaves in this subtree. After compute_jump_pointers(), up[k] is the
 * 2^k-th ancestor, where the root is its own parent.
 */
class UfTree {
    public:
        typedef std::vector<UfTree*> vutp;
        int l = -1, r = -1; // smallest and largest new leaf index in the subtree
        UfTree *lt, *rt;    // children; both null for a leaf
        int t;              // unification time
        int orig_ind;       // original vertex index of a leaf
        vutp up;            // up[k] = 2^k-th ancestor

        // Members are initialized in declaration order (avoids -Wreorder).
        UfTree(int a_t, UfTree* a_lt, UfTree* a_rt, int a_orig_ind)
            : lt(a_lt), rt(a_rt), t(a_t), orig_ind(a_orig_ind), up(MAX_JUMP_PTR, nullptr) {}

        /*
         * Assigns new indices to leaves in post-order, starting at low_new_ind,
         * and sets l / r for every node plus up[0] = par.
         * ind[orig] receives the new index of leaf orig; leaf_ptrs[orig] receives
         * the leaf node itself.
         * NOTE(review): recursive; the depth equals the tree height, which can
         * approach the number of leaves, so a large stack may be needed.
         */
        void postorder_traverse(int low_new_ind, UfTree* par, std::vector<int>& ind, vutp& leaf_ptrs) {
            up[0] = par;
            if(lt == nullptr) {
                l = r = low_new_ind;
                ind[orig_ind] = l;
                leaf_ptrs[orig_ind] = this;
                return;
            }

            // The left subtree occupies [low_new_ind, lt->r], so the right
            // subtree starts right after it.
            lt->postorder_traverse(low_new_ind, this, ind, leaf_ptrs);
            rt->postorder_traverse(lt->r + 1, this, ind, leaf_ptrs);

            l = lt->l;
            r = rt->r;
        }

        /*
         * Fills up[1..] for this subtree in pre-order: a node's table is built
         * from its parent's, which must already be complete. Requires up[0] to be
         * set (postorder_traverse does this).
         */
        void compute_jump_pointers() {
            for(int i = 1; i < MAX_JUMP_PTR; i++) {
                up[i] = up[i - 1]->up[i - 1];
            }

            if(lt == nullptr) return;

            lt->compute_jump_pointers();
            rt->compute_jump_pointers();
        }

        /*
         * Returns the highest ancestor of this node (or the node itself) whose
         * time is <= qt. Assumes ancestor times never decrease toward the root,
         * so the greedy lift from the largest jump size down is valid.
         */
        UfTree* find_most_recent_node(int qt) {
            UfTree* cur = this;
            // Each jump size is tried exactly once. Retrying a size after a
            // successful jump never terminates at the root, whose up[] points to
            // itself.
            for(int k = MAX_JUMP_PTR - 1; k >= 0; k--) {
                if(cur->up[k]->t <= qt) cur = cur->up[k];
            }
            return cur;
        }
};

class Uf {
    public:
        typedef vector<UfTree*> vutp;
        int n;
        vutp uf_tree_node;
        vi par;
        vi csize;
        int ncomps;

        Uf(int a_n): n(a_n), uf_tree_node(n, nullptr), par(n, 0), csize(n, 1), ncomps(a_n) {
            for(int i = 0; i < n; i++) {
                uf_tree_node[i] = new UfTree(0, nullptr, nullptr, i);
                par[i] = i;
            }
        }

        int find(int i) {
            return (i == par[i] ? i : par[i] = find(par[i])); // ! warning, might be wrong, I don't usually implement it this way
        }

        int conn(int i, int j) {
            return find(i) == find(j);
        }

        void unify(int i, int j, int t) {
            int pari = find(i), parj = find(j);

            if(pari == parj) return;

            UfTree* new_uftree_node = new UfTree(t, uf_tree_node[parj], uf_tree_node[pari], -1);

            if(csize[pari] < csize[parj]) {
                uf_tree_node[parj] = new_uftree_node;
                par[pari] = parj;
                csize[parj] += csize[pari];
            } else {
                uf_tree_node[pari] = new_uftree_node;
                par[parj] = pari;
                csize[pari] += csize[parj];
            }

            ncomps++;
        }
};

/*
 * Merge sort tree over point indices [l, r]. Point k has x-coordinate k (its
 * position in the array passed to build) and y-coordinate a[k].second. Each
 * node stores the sorted y-values of the points in its interval.
 */
class Tree {
    public:
        int l, r;
        Tree *lt, *rt;      // owned children; null for a leaf
        std::vector<int> y; // sorted y-values of points in [l, r]

        Tree(int a_l, int a_r): l(a_l), r(a_r), lt(nullptr), rt(nullptr), y() {}

        // The node owns its children; copying would lead to a double delete.
        Tree(const Tree&) = delete;
        Tree& operator=(const Tree&) = delete;

        ~Tree() {
            delete lt;
            delete rt;
        }

        // Appends the merged, sorted y-values of both children to y.
        void combine() {
            y.reserve(y.size() + lt->y.size() + rt->y.size());
            std::merge(lt->y.begin(), lt->y.end(), rt->y.begin(), rt->y.end(), std::back_inserter(y));
        }

        // Builds the subtree over a[l..r]; only a[k].second is stored.
        void build(const std::vector<std::pair<int, int>>& a) {
            if(l == r) {
                y.push_back(a[l].second);
                return;
            }

            int m = (l + r) >> 1;
            lt = new Tree(l, m);
            rt = new Tree(m + 1, r);

            lt->build(a);
            rt->build(a);

            combine();
        }

        /*
         * True iff some point has x in [qvll, qvlr] and y in [qvrl, qvrr].
         * The x-range may be empty or extend beyond [l, r].
         */
        bool qry(int qvll, int qvlr, int qvrl, int qvrr) {
            if(qvll > qvlr || qvll > r || qvlr < l) return false;

            if(qvll <= l && r <= qvlr) {
                // Smallest stored y >= qvrl, then check it is within the bound.
                auto it = std::lower_bound(y.begin(), y.end(), qvrl);
                return it != y.end() && *it <= qvrr;
            }

            // Partial overlap implies l < r, so both children exist.
            return lt->qry(qvll, qvlr, qvrl, qvrr) || rt->qry(qvll, qvlr, qvrl, qvrr);
        }
};

/*
 * IOI 2018 "Werewolf" entry point. For each query i, returns 1 if some vertex
 * is both:
 *  - reachable from s[i] using only vertices >= l[i], and
 *  - reachable from e[i] using only vertices <= r[i];
 * otherwise it returns 0.
 *
 * Method:
 *  - ufr merges edges by increasing max endpoint (time = max endpoint).
 *  - ufl merges edges by decreasing min endpoint (time = -min endpoint, so
 *    times still grow toward the root).
 *  - Post-order numbering turns each subtree into a contiguous index range.
 *  - The answer asks whether some vertex v has vl[v] in the left range and
 *    vr[v] in the right range, which a merge sort tree answers.
 */
vi check_validity(int n, vi x, vi y, vi s, vi e, vi l, vi r) {
    int q = s.size(), m = x.size();

    // With no vertices nothing is reachable (and Tree(0, -1) would never reach
    // a leaf while building).
    if(n <= 0) return vi(q, 0);

    // Edge list representation.
    vpi el;
    el.reserve(m);
    for(int i = 0; i < m; i++) {
        el.pb({x[i], y[i]});
    }

    // Right tree: edges in increasing order of max endpoint, merged at time
    // max(u, v).
    sort(el.begin(), el.end(), [](pi e1, pi e2) {return max(e1.fi, e1.se) < max(e2.fi, e2.se);});
    Uf ufr(n);
    for(const pi& ed : el) {
        ufr.unify(ed.fi, ed.se, max(ed.fi, ed.se));
    }

    // Left tree: edges in decreasing order of min endpoint, merged at time
    // -min(u, v).
    sort(el.begin(), el.end(), [](pi e1, pi e2) {return min(e1.fi, e1.se) > min(e2.fi, e2.se);});
    Uf ufl(n);
    for(const pi& ed : el) {
        ufl.unify(ed.fi, ed.se, -min(ed.fi, ed.se));
    }

    typedef vector<UfTree*> vutp;
    vi vl(n, -1), vr(n, -1);                // post-order index of each vertex per tree
    vutp vlp(n, nullptr), vrp(n, nullptr);  // leaf node of each vertex per tree

    // Number the leaves of every component tree consecutively and build the
    // jump pointers (each root is its own parent). The graph is expected to be
    // connected, but visiting every representative also handles a forest
    // instead of leaving some leaf pointers null.
    auto index_forest = [n](Uf& uf, vi& ind, vutp& leaf_ptrs) {
        int next_ind = 0;
        for(int v = 0; v < n; v++) {
            if(uf.find(v) != v) continue;
            UfTree* root = uf.uf_tree_node[v];
            root->postorder_traverse(next_ind, root, ind, leaf_ptrs);
            root->compute_jump_pointers();
            next_ind = root->r + 1;
        }
    };
    index_forest(ufl, vl, vlp);
    index_forest(ufr, vr, vrp);

    // Leaves are created with time 0; give left-tree leaves time -(n - 1).
    for(int i = 0; i < n; i++) {
        vlp[i]->t = -(n - 1);
    }

    // Merge sort tree over the points (vl[v], vr[v]), x = vl and y = vr. vl is
    // a permutation of 0..n-1, so after sorting, position k holds vl == k.
    vpi vlvr(n, {-1, -1});
    for(int i = 0; i < n; i++) {
        vlvr[i] = {vl[i], vr[i]};
    }
    sort(vlvr.begin(), vlvr.end());

    Tree tr(0, n - 1);
    tr.build(vlvr);

    vi ans(q, 0);
    for(int i = 0; i < q; i++) {
        UfTree* ln = vlp[s[i]]->find_most_recent_node(-l[i]);
        UfTree* rn = vrp[e[i]]->find_most_recent_node(r[i]);
        ans[i] = tr.qry(ln->l, ln->r, rn->l, rn->r) ? 1 : 0;
    }

    // NOTE(review): the UfTree nodes allocated by Uf are never freed.
    return ans;
}

Compilation message

werewolf.cpp: In constructor 'UfTree::UfTree(int, UfTree*, UfTree*, int)':
werewolf.cpp:51:13: warning: 'UfTree::t' will be initialized after [-Wreorder]
   51 |         int t; // unification time
      |             ^
werewolf.cpp:50:17: warning:   'UfTree* UfTree::lt' [-Wreorder]
   50 |         UfTree *lt, *rt;
      |                 ^~
werewolf.cpp:55:9: warning:   when initialized here [-Wreorder]
   55 |         UfTree(int a_t, UfTree* a_lt, UfTree* a_rt, int a_orig_ind): t(a_t), lt(a_lt), rt(a_rt), orig_ind(a_orig_ind), up(MAX_JUMP_PTR, nullptr) {};
      |         ^~~~~~
werewolf.cpp: In member function 'bool Tree::qry(int, int, int, int)':
werewolf.cpp:215:22: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  215 |                 if(r == y.size()) return false;
      |                    ~~^~~~~~~~~~~
# 결과 실행 시간 메모리 Grader output
1 Execution timed out 4059 ms 348 KB Time limit exceeded
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Execution timed out 4059 ms 348 KB Time limit exceeded
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Execution timed out 4088 ms 261064 KB Time limit exceeded
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Execution timed out 4059 ms 348 KB Time limit exceeded
2 Halted 0 ms 0 KB -