Submission #1346086

# | Time | Username | Problem | Language | Result | Execution time | Memory
1346086 | | Zicrus | Hieroglyphs (IOI24_hieroglyphs) | C++20 | 28 / 100 | 262 ms | 45536 KiB
#include "hieroglyphs.h"
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<ll, ll> pll;
typedef pair<int, int> pii;
typedef vector<int> vi;
#define rep(i, a, b) for(int i = a; i < (b); ++i)
#define all(x) x.begin(), x.end()
#define sz(x) (ll)(x).size()
constexpr ll inf = 4e18;
mt19937 mt(time(0));

// Builds the failure answer: a vector holding the single value -1.
// solve() returns this whenever one of its checks fails.
vector<ll> imp() {
    vector<ll> failure(1, -1);
    return failure;
}

// Minimum segment tree over positions [0, nt-1] with point assignment.
// Every slot starts at `inf`, so a query over untouched or empty ranges
// yields `inf`.
struct segtree {
    ll nt = 0;        // leaf count: smallest power of two that is >= n (at least 1)
    vector<ll> tree;  // heap layout: root is node 1, node k has children 2k and 2k+1

    // Creates a tree able to hold positions 0..n-1, all initialised to inf.
    segtree(ll n) {
        for (nt = 1; nt < n; nt <<= 1) { }
        tree.assign(2 * nt, inf);
    }

    // Assigns value v to position i.
    void point_set(ll i, ll v) { point_set(1, 0, nt - 1, i, v); }

    // Recursive worker: node k covers [tl, tr].
    void point_set(ll k, ll tl, ll tr, ll i, ll v) {
        if (tl == tr) {
            tree[k] = v;
            return;
        }
        const ll mid = tl + (tr - tl) / 2;
        if (i > mid) {
            point_set(2 * k + 1, mid + 1, tr, i, v);
        } else {
            point_set(2 * k, tl, mid, i, v);
        }
        // Recompute this node from its two children on the way back up.
        tree[k] = min(tree[2 * k], tree[2 * k + 1]);
    }

    // Minimum over positions in [l, r] (inclusive). The bounds may extend
    // past the tree on either side; parts outside contribute inf.
    ll range_min(ll l, ll r) { return range_min(1, 0, nt - 1, l, r); }

    // Recursive worker: node k covers [tl, tr].
    ll range_min(ll k, ll tl, ll tr, ll l, ll r) {
        const bool disjoint = r < tl || l > tr;
        if (disjoint) return inf;
        const bool covered = l <= tl && tr <= r;
        if (covered) return tree[k];
        const ll mid = tl + (tr - tl) / 2;
        const ll left_min = range_min(2 * k, tl, mid, l, r);
        const ll right_min = range_min(2 * k + 1, mid + 1, tr, l, r);
        return min(left_min, right_min);
    }
};

// One paired occurrence of a value, as produced by get_symbols: an inclusive
// index range [l0, r0] in the first sequence and [l1, r1] in the second.
// get_symbols always collapses at least one of the two ranges to a single
// index.
struct symbol {
    ll c = -1;     // the value this symbol stands for
    ll l0 = inf;   // first sequence: lowest candidate index
    ll r0 = -inf;  // first sequence: highest candidate index
    ll l1 = inf;   // second sequence: lowest candidate index
    ll r1 = -inf;  // second sequence: highest candidate index

    symbol(ll c, ll l0, ll r0, ll l1, ll r1)
        : c(c), l0(l0), r0(r0), l1(l1), r1(r1) { }

    // True when this symbol's lower bound is strictly below o's upper bound
    // in BOTH sequences. Used as the comparator for sort() in solve().
    bool operator<(const symbol &o) const {
        if (l0 >= o.r0) return false;
        return l1 < o.r1;
    }
};

// Pairs up the occurrences of every value that appears in both sequences.
//
// For a value occurring p times in `a` and q times in `b`, min(p, q) symbols
// are emitted. The sequence with fewer occurrences contributes its j-th
// occurrence as a single index; the other contributes the inclusive range
// from its j-th to its (j + |p - q|)-th occurrence.
//
// Symbols are returned grouped by ascending value, and by ascending j within
// a value. All elements of `a` and `b` must be non-negative.
vector<symbol> get_symbols(vector<ll> &a, vector<ll> &b) {
    ll n = sz(a), m = sz(b);

    // Size the occurrence tables from the data instead of a fixed 2e5+1
    // buckets: no out-of-range index for larger values, and no allocation of
    // hundreds of thousands of empty vectors for small inputs.
    ll top = -1;
    for (ll x : a) if (x > top) top = x;
    for (ll x : b) if (x > top) top = x;

    vector<vector<ll>> ids0(top + 1), ids1(top + 1);
    for (ll i = 0; i < n; i++) ids0[a[i]].push_back(i);
    for (ll i = 0; i < m; i++) ids1[b[i]].push_back(i);

    vector<symbol> syms;
    for (ll v = 0; v <= top; v++) {
        const vector<ll> &p0 = ids0[v];
        const vector<ll> &p1 = ids1[v];
        if (sz(p0) < sz(p1)) {
            // Fewer occurrences in `a`: pin the a-index, leave a window in `b`.
            ll diff = sz(p1) - sz(p0);
            for (ll j = 0; j < sz(p0); j++) {
                syms.emplace_back(v, p0[j], p0[j], p1[j], p1[j + diff]);
            }
        }
        else {
            // Fewer (or equally many) occurrences in `b`: pin the b-index.
            ll diff = sz(p0) - sz(p1);
            for (ll j = 0; j < sz(p1); j++) {
                syms.emplace_back(v, p0[j], p0[j + diff], p1[j], p1[j]);
            }
        }
    }
    return syms;
}

// Cross-checks the symbols against each other and returns false as soon as
// one of the range comparisons below fires; true if none does.
//
// n and m size the segment trees, which are indexed by position in the first
// and second sequence respectively. The segtree only supports minima, so
// maxima are obtained by storing negated values and negating the query result.
bool symbols_valid(ll n, ll m, vector<symbol> &syms) {
    // Indexed by first-sequence position; written only by symbols whose
    // first-sequence range is a single index (l0 == r0).
    segtree min_l1_by0(n), min_r1_by0(n), neg_l1_by0(n), neg_r1_by0(n);
    // Indexed by second-sequence position; written only by symbols whose
    // second-sequence range is a single index (l1 == r1).
    segtree min_l0_by1(m), min_r0_by1(m), neg_l0_by1(m), neg_r0_by1(m);

    for (const symbol &cur : syms) {
        if (cur.l0 == cur.r0) {
            const ll at = cur.l0;
            min_l1_by0.point_set(at, cur.l1);
            neg_r1_by0.point_set(at, -cur.r1);
            min_r1_by0.point_set(at, cur.r1);
            neg_l1_by0.point_set(at, -cur.l1);
        }
        if (cur.l1 == cur.r1) {
            const ll at = cur.l1;
            min_l0_by1.point_set(at, cur.l0);
            neg_r0_by1.point_set(at, -cur.r0);
            min_r0_by1.point_set(at, cur.r0);
            neg_l0_by1.point_set(at, -cur.l0);
        }
    }

    for (const symbol &cur : syms) {
        // A pinned symbol at or after this one's upper end in one sequence
        // must not have its other-sequence upper end before this one's lower
        // end; symmetrically for pinned symbols at or before the lower end.
        const bool violated =
            min_r1_by0.range_min(cur.r0, inf) < cur.l1 ||
            -neg_l1_by0.range_min(0, cur.l0) > cur.r1 ||
            min_r0_by1.range_min(cur.r1, inf) < cur.l0 ||
            -neg_l0_by1.range_min(0, cur.l1) > cur.r0;
        if (violated) return false;

        // For a symbol with a genuine window, reject when the pinned symbols
        // inside that window land on both sides of this symbol's single index
        // in the other sequence.
        if (cur.l0 < cur.r0) {
            const bool has_smaller = min_l1_by0.range_min(cur.l0, cur.r0) < cur.l1;
            const bool has_larger = -neg_r1_by0.range_min(cur.l0, cur.r0) > cur.l1;
            if (has_smaller && has_larger) return false;
        }
        else if (cur.l1 < cur.r1) {
            const bool has_smaller = min_l0_by1.range_min(cur.l1, cur.r1) < cur.l0;
            const bool has_larger = -neg_r0_by1.range_min(cur.l1, cur.r1) > cur.l0;
            if (has_smaller && has_larger) return false;
        }
    }
    return true;
}

// For every index i, returns the index of the previous occurrence of a[i]
// in `a`, or -1 when a[i] has not appeared before.
//
// All elements of `a` must be non-negative.
vector<ll> eq_prev(vector<ll> &a) {
    ll n = sz(a);

    // Size the "last index seen" table from the data instead of a fixed
    // 2e5+1 slots, so values above that bound do not index out of range.
    ll top = -1;
    for (ll x : a) if (x > top) top = x;

    vector<ll> last_seen(top + 1, -1);
    vector<ll> prev(n);
    for (ll i = 0; i < n; i++) {
        prev[i] = last_seen[a[i]];
        last_seen[a[i]] = i;
    }
    return prev;
}

// Groups the indices of `a` by value: entry v of the result is the ordered
// set of positions i with a[i] == v. The result always has 2e5+1 entries, so
// every element of `a` must lie in [0, 2e5].
vector<set<ll>> val_sets(vector<ll> &a) {
    vector<set<ll>> by_value(2e5+1);
    ll pos = 0;
    for (ll value : a) {
        by_value[value].insert(pos);
        ++pos;
    }
    return by_value;
}

// Smallest element of sets[v] that is strictly greater than i, or inf when
// there is none.
ll strict_next(vector<set<ll>> &sets, ll i, ll v) {
    const set<ll> &bucket = sets[v];
    const auto after = bucket.upper_bound(i);
    if (after != bucket.end()) return *after;
    return inf;
}

// Checks the candidate sequence `res` against one direction of the pair
// (a, b); solve() calls it once as (a, b, upper, lower) and once with the
// roles swapped.
//
// upper[i] / lower[i] are the indices in `a` / `b` that res[i] was matched
// to. `upper` is binary-searched, so it must be sorted ascending.
// Returns false as soon as one of the dp comparisons fires.
//
// NOTE(review): the invariant that `dp` encodes is not spelled out anywhere
// in this file; the comments below describe only the mechanics.
bool matching_valid(vector<ll> &a, vector<ll> &b, vector<ll> &res, vector<ll> &upper, vector<ll> &lower) {
    const ll len = sz(res);
    vector<set<ll>> pos_in_a = val_sets(a);
    vector<set<ll>> pos_in_b = val_sets(b);
    vector<ll> prev_same_a = eq_prev(a);
    vector<ll> prev_same_b = eq_prev(b);

    segtree dp(len);
    vector<ll> last_seen(2e5+1, -1);  // value -> latest index in res, -1 if none yet
    for (ll i = 0; i < len; i++) { // TODO - handle -1s
        const ll value = res[i];
        const ll pa = prev_same_a[upper[i]];  // previous equal element in a (or -1)
        const ll pb = prev_same_b[lower[i]];  // previous equal element in b (or -1)
        const ll pu = last_seen[value];       // previous equal element in res (or -1)
        // Index of the last matched a-position strictly before pa.
        const ll pba = (lower_bound(all(upper), pa) - upper.begin()) - 1;
        if (dp.range_min(pu, pba) < pb)
            return false;

        // dp[i]: first b-position holding `value` strictly after the minimum
        // dp entry in [pu, i-1]; for a first occurrence the search starts
        // from -1, i.e. from the beginning of b.
        const ll v = (pu == -1) ? -1 : dp.range_min(pu, i - 1);
        dp.point_set(i, strict_next(pos_in_b, v, value));
        last_seen[value] = i;
    }

    // Same comparison once more per value, using the LAST occurrence of the
    // value in a and b in place of the "previous equal" positions.
    const ll value_count = sz(last_seen);
    for (ll s = 0; s < value_count; s++) {
        const ll lk = last_seen[s];
        if (lk == -1) continue;
        const ll la = *pos_in_a[s].rbegin();
        const ll lb = *pos_in_b[s].rbegin();
        const ll pba = (lower_bound(all(upper), la) - upper.begin()) - 1;
        if (dp.range_min(lk, pba) < lb)
            return false;
    }
    return true;
}

// Builds the candidate answer for sequences `a` and `b`, or returns imp()
// (the vector {-1}) when any check fails.
//
// Steps: derive the symbols, validate them, sort them with
// symbol::operator<, greedily embed the resulting value sequence into both
// `a` and `b`, then run matching_valid in both directions.
vector<ll> solve(vector<ll> &a, vector<ll> &b) {
    ll n = sz(a), m = sz(b);
    vector<symbol> syms = get_symbols(a, b);
    if (!symbols_valid(n, m, syms)) return imp();

    // NOTE(review): std::sort requires a strict weak ordering; presumably
    // symbols_valid() guarantees symbol::operator< is one on `syms` — confirm.
    sort(all(syms));
    ll k = sz(syms);

    // Leftmost embedding of the sorted values: upper[i] / lower[i] is the
    // index in a / b that res[i] is matched to.
    ll ai = 0, bi = 0;
    vector<ll> res(k), upper(k), lower(k);
    for (ll i = 0; i < k; i++) {
        res[i] = syms[i].c;
        while (ai < n && a[ai] != res[i]) ai++;
        while (bi < m && b[bi] != res[i]) bi++;
        // Ran off the end of either sequence: the order cannot be embedded.
        if (ai >= n || bi >= m) return imp();
        upper[i] = ai++;
        lower[i] = bi++;
    }

    if (!matching_valid(a, b, res, upper, lower)) return imp();
    if (!matching_valid(b, a, res, lower, upper)) return imp();
    return res;
}

// Grader entry point: widens both inputs to 64-bit, delegates to solve(),
// and narrows the answer back to int.
vector<int> ucs(vector<int> a, vector<int> b) {
    vector<ll> na(a.begin(), a.end());
    vector<ll> nb(b.begin(), b.end());

    vector<ll> lres = solve(na, nb);

    vector<int> res;
    res.reserve(lres.size());
    for (ll x : lres) res.push_back(x);
    return res;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...