/*
 * Submission #1346453
 *
 * #       | Time | Username | Problem                         | Language | Result    | Execution time | Memory
 * 1346453 | -    | Zicrus   | Hieroglyphs (IOI24_hieroglyphs) | C++20    | 100 / 100 | 196 ms         | 40400 KiB
 */
#include "hieroglyphs.h"
#include <bits/stdc++.h>
using namespace std;
// Competitive-programming shorthands shared by the whole file.
// ld, pll, pii, vi, rep and mt are not referenced by any code shown in this file.
typedef long long ll;
typedef long double ld;
typedef pair<ll, ll> pll;
typedef pair<int, int> pii;
typedef vector<int> vi;
#define rep(i, a, b) for(int i = a; i < (b); ++i)
// all(x) expands to TWO arguments "x.begin(), x.end()"; e.g. in ucs,
// partial_sort(all(syms), syms.end()) means partial_sort(begin, end, end).
#define all(x) x.begin(), x.end()
// Container size cast to signed ll, so expressions like sz(p) - sz(q) may be negative.
#define sz(x) (ll)(x).size()
// Sentinel "infinity": initial value / identity of the min segment tree and the
// "no such position" result of strict_next.
constexpr ll inf = 4e18;
// Wall-clock-seeded RNG; not used by the code shown here.
mt19937 mt(time(0));

// Answer reported when no valid sequence exists: the single-element list {-1}.
vector<int> imp() {
    vector<int> none(1, -1);
    return none;
}

// Point-assign / range-minimum segment tree over indices [0, nt).
// It is stored as an implicit binary heap: node 1 is the root covering
// [0, nt - 1], and node k has children 2k and 2k + 1.
// Every slot starts at `inf`. A query over an empty or fully out-of-range
// interval also yields `inf`.
struct segtree {
    ll nt = 0;          // leaf count: smallest power of two >= n (at least 1)
    vector<ll> tree;    // 2 * nt heap-ordered nodes, each holding its range minimum

    segtree(ll n) {
        nt = 1;
        while (nt < n) nt <<= 1;
        tree.assign(2 * nt, inf);
    }

    // Sets position i to v and refreshes the minima along the root path.
    void point_set(ll i, ll v) { point_set(1, 0, nt - 1, i, v); }
    void point_set(ll k, ll tl, ll tr, ll i, ll v) {
        if (tl != tr) {
            const ll mid = tl + (tr - tl) / 2;
            if (i > mid) point_set(2 * k + 1, mid + 1, tr, i, v);
            else point_set(2 * k, tl, mid, i, v);
            tree[k] = min(tree[2 * k], tree[2 * k + 1]);
        } else {
            tree[k] = v;
        }
    }

    // Minimum over positions [l, r], clipped to [0, nt - 1]; inf when l > r.
    ll range_min(ll l, ll r) { return range_min(1, 0, nt - 1, l, r); }
    ll range_min(ll k, ll tl, ll tr, ll l, ll r) {
        const bool disjoint = r < tl || tr < l;
        if (disjoint) return inf;
        const bool contained = l <= tl && tr <= r;
        if (contained) return tree[k];
        const ll mid = tl + (tr - tl) / 2;
        const ll left_min = range_min(2 * k, tl, mid, l, r);
        const ll right_min = range_min(2 * k + 1, mid + 1, tr, l, r);
        return min(left_min, right_min);
    }
};

struct symbol {
    ll c = -1;
    ll l0 = inf, r0 = -inf, l1 = inf, r1 = -inf;

    symbol(ll c, ll l0, ll r0, ll l1, ll r1) : c(c), l0(l0), r0(r0), l1(l1), r1(r1) { }

    bool operator<(const symbol &o) const {
        return l0 < o.r0 && l1 < o.r1;
    }
};

vector<symbol> get_symbols(vector<int> &a, vector<int> &b) {
    ll n = sz(a), m = sz(b);
    vector<vector<ll>> ids0(2e5+1), ids1(2e5+1);
    for (ll i = 0; i < n; i++) ids0[a[i]].push_back(i);
    for (ll i = 0; i < m; i++) ids1[b[i]].push_back(i);
    
    vector<symbol> syms;
    for (ll i = 0; i <= 2e5; i++) {
        ll diff = abs(sz(ids1[i]) - sz(ids0[i]));
        if (sz(ids0[i]) < sz(ids1[i])) {
            for (ll j = 0; j < sz(ids0[i]); j++) {
                syms.emplace_back(i, ids0[i][j], ids0[i][j], ids1[i][j], ids1[i][j+diff]);
            }
        }
        else {
            for (ll j = 0; j < sz(ids1[i]); j++) {
                syms.emplace_back(i, ids0[i][j], ids0[i][j+diff], ids1[i][j], ids1[i][j]);
            }
        }
    }
    return syms;
}

vector<ll> eq_prev(vector<int> &a) {
    ll n = sz(a);
    vector<ll> last_seen(2e5+1, -1);
    vector<ll> prev(n);
    for (ll i = 0; i < n; i++) {
        prev[i] = last_seen[a[i]];
        last_seen[a[i]] = i;
    }
    return prev;
}

vector<set<ll>> val_sets(vector<int> &a) {
    ll n = sz(a);
    vector<set<ll>> sets(2e5+1);
    for (ll i = 0; i < n; i++) {
        sets[a[i]].insert(i);
    }
    return sets;
}

// Smallest position in sets[v] strictly greater than i, or inf if there is none.
ll strict_next(vector<set<ll>> &sets, ll i, ll v) {
    const set<ll> &positions = sets[v];
    const auto after = positions.upper_bound(i);
    if (after == positions.end()) return inf;
    return *after;
}

// Validity check for the candidate `res`, which is embedded into `a` at
// positions `upper` and into `b` at positions `lower`.
// ucs calls it in both orientations:
//   (a, b, res, upper, lower) and (b, a, res, lower, upper).
// Parameters:
//   a, b  - the two sequences; values assumed in [0, 2e5] (per-value tables
//           below have 2e5+1 slots).
//   res   - candidate answer.
//   upper - upper[i] is the position in `a` matched to res[i]. ucs builds it
//           strictly increasing with a[upper[i]] == res[i]; the lower_bound
//           calls below rely on the ordering.
//   lower - lower[i] is the position in `b` matched to res[i].
// Returns false as soon as one of the checks below fails, true otherwise.
// NOTE(review): the checks appear to detect a common subsequence of a and b
// that cannot be embedded into res; confirm against the intended proof.
bool matching_valid(vector<int> &a, vector<int> &b, vector<int> &res, vector<ll> &upper, vector<ll> &lower) {
    ll k = sz(res);
    vector<set<ll>> a_sets = val_sets(a);
    vector<set<ll>> b_sets = val_sets(b);
    vector<ll> a_prev = eq_prev(a);

    // dp[i] (assigned below) = first position in b strictly after v that holds
    // res[i], or inf. Unassigned slots stay inf.
    segtree dp(k);
    // last_seen[x] = last index in res (so far) whose value is x, or -1.
    vector<ll> last_seen(2e5+1, -1);
    for (ll i = 0; i < k; i++) {
        // Previous index in res with the same value as res[i], or -1.
        ll pu = last_seen[res[i]];
        // Last index j with upper[j] < (previous occurrence of res[i] in a,
        // before upper[i]); -1 if there is none.
        ll pba = lower_bound(all(upper), a_prev[upper[i]]) - upper.begin() - 1;
        // When pu == -1 the query range starts below 0 and is clipped to
        // [0, pba]. An empty range yields inf, and then strict_next yields inf,
        // so the check passes.
        if (strict_next(b_sets, dp.range_min(pu, pba), res[i]) < lower[i]) return false;
        ll v = (pu == -1) ? -1 : dp.range_min(pu, i-1);
        dp.point_set(i, strict_next(b_sets, v, res[i]));
        last_seen[res[i]] = i;
    }

    // Final check for every value s that occurs in res.
    for (ll s = 0; s <= 2e5; s++) {
        if (last_seen[s] == -1) continue;
        // Last occurrence of s in a and in b (s is in res, so both sets are non-empty).
        ll la = *--a_sets[s].end();
        ll lb = *--b_sets[s].end();
        // Last index of s in res.
        ll lk = last_seen[s];
        // Last index j with upper[j] < la.
        ll pba = lower_bound(all(upper), la) - upper.begin() - 1;
        if (dp.range_min(lk, pba) < lb) return false;
    }
    return true;
}

// Presumably the grader entry point (the file includes "hieroglyphs.h").
// It builds one candidate sequence from `a` and `b` and returns it only if it
// passes matching_valid in both orientations; otherwise it returns {-1} via
// imp().
// NOTE(review): per the problem this candidate is the universal common
// subsequence; the proof is not visible here.
vector<int> ucs(vector<int> a, vector<int> b) {
    ll n = sz(a), m = sz(b);
    // Candidate occurrences of each value, with their position windows.
    vector<symbol> syms = get_symbols(a, b);
    // all(syms) expands to "syms.begin(), syms.end()", so this is
    // partial_sort(begin, end, end): a full sort of syms by symbol::operator<.
    // That comparator is not a strict weak ordering, although the standard
    // formally requires one for every sorting algorithm.
    // NOTE(review): partial_sort (heap-based) was presumably chosen over
    // std::sort to avoid out-of-range behaviour with an inconsistent
    // comparator. Any bad order is expected to be rejected by matching_valid
    // below. Do not swap this call without re-verifying.
    partial_sort(all(syms), syms.end()); ll k = sz(syms);

    ll ai = 0, bi = 0;
    vector<int> res(k);
    // upper[i] / lower[i]: positions in a / b greedily matched to res[i].
    vector<ll> upper(k), lower(k);
    for (ll i = 0; i < k; i++) {
        res[i] = syms[i].c;
        // Greedy: next occurrence of res[i] at or after the current cursor.
        while (ai < n && a[ai] != res[i]) ai++;
        while (bi < m && b[bi] != res[i]) bi++;
        upper[i] = ai, lower[i] = bi;
        // Either cursor running off the end means res is not a common
        // subsequence in this order. Otherwise both cursors step past the
        // matched positions.
        if (ai++ >= n || bi++ >= m) return imp();
    }
    
    // Validate in both orientations (a/b and upper/lower swapped).
    if (!matching_valid(a, b, res, upper, lower)) return imp();
    if (!matching_valid(b, a, res, lower, upper)) return imp();
    return res;
}
/*
 * Grader results (these had not loaded when the page was captured):
 *
 * # | Verdict | Execution time | Memory | Grader output
 * Fetching results...
 * # | Verdict | Execution time | Memory | Grader output
 * Fetching results...
 * # | Verdict | Execution time | Memory | Grader output
 * Fetching results...
 * # | Verdict | Execution time | Memory | Grader output
 * Fetching results...
 * # | Verdict | Execution time | Memory | Grader output
 * Fetching results...
 * # | Verdict | Execution time | Memory | Grader output
 * Fetching results...
 * # | Verdict | Execution time | Memory | Grader output
 * Fetching results...
 */