제출 #428981

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
428981 | | yuto1115 | 전선 연결 (IOI17_wiring) | C++17 | 100 / 100 | 137 ms | 22484 KiB
#include "wiring.h"
#include<bits/stdc++.h>
// Loop shorthands. The loop index is always declared as `ll` (long long).
// rep(i, n): i runs 0, 1, ..., n-1.
#define rep(i, n) for(ll i = 0; i < ll(n); i++)
// rep2(i, s, n): i runs s, s+1, ..., n-1.
#define rep2(i, s, n) for(ll i = ll(s); i < ll(n); i++)
// rrep(i, n): i runs n-1, n-2, ..., 0.
#define rrep(i, n) for(ll i = ll(n)-1; i >= 0; i--)
// Container shorthands: member-call abbreviations and iterator-pair expansions.
#define pb push_back
#define eb emplace_back
#define all(a) a.begin(),a.end()
#define rall(a) a.rbegin(),a.rend()
using namespace std;
// Short aliases for the integer, pair and (nested) vector types used below.
using ll = long long;
using P = pair<int, int>;
using vi = vector<int>;
using vvi = vector<vi>;
using vl = vector<ll>;
using vvl = vector<vl>;
using vp = vector<P>;
using vvp = vector<vp>;
using vb = vector<bool>;
using vvb = vector<vb>;
using vs = vector<string>;
// Large sentinel values for `int` (about 1e9) and `ll` (about 1e18).
const int inf = 1001001001;
const ll linf = 1001001001001001001;

// Lowers `a` to `b` when `a` is strictly greater than `b`.
// Returns true iff `a` was overwritten.
template<class T>
bool chmin(T &a, T b) {
    const bool improved = a > b;
    if (improved) a = b;
    return improved;
}

// Raises `a` to `b` when `a` is strictly less than `b`.
// Returns true iff `a` was overwritten.
template<class T>
bool chmax(T &a, T b) {
    const bool improved = a < b;
    if (improved) a = b;
    return improved;
}

// Value of the line y = a*x + b at the integer abscissa x.
ll eval(ll a, ll b, int x) {
    const ll scaled = a * x;
    return scaled + b;
}

class segtree {
    int n;
    // ax + b
    vl a, b;
    vi ls, rs;
    
    void _update(int k, ll na, ll nb) {
        int l = ls[k], r = rs[k], m = (l + r) / 2;
        if (l + 1 == r) {
            if (eval(a[k], b[k], l) > eval(na, nb, l)) {
                a[k] = na;
                b[k] = nb;
            }
            return;
        }
        if (eval(a[k], b[k], l) <= eval(na, nb, l) and eval(a[k], b[k], r) <= eval(na, nb, r)) return;
        if (eval(a[k], b[k], l) >= eval(na, nb, l) and eval(a[k], b[k], r) >= eval(na, nb, r)) {
            a[k] = na;
            b[k] = nb;
            return;
        }
        if (eval(a[k], b[k], m) > eval(na, nb, m)) {
            swap(a[k], na);
            swap(b[k], nb);
        }
        if (eval(a[k], b[k], l) > eval(na, nb, l)) {
            _update(2 * k, na, nb);
            return;
        }
        if (eval(a[k], b[k], r) > eval(na, nb, r)) {
            _update(2 * k + 1, na, nb);
            return;
        }
        assert(false);
    }

public:
    segtree(int _n) {
        n = 1;
        while (n < _n) n *= 2;
        a.assign(2 * n, 0);
        b.assign(2 * n, linf);
        ls.resize(2 * n);
        rs.resize(2 * n);
        rep(i, n) {
            ls[n + i] = i;
            rs[n + i] = i + 1;
        }
        rrep(i, n) {
            ls[i] = ls[2 * i];
            rs[i] = rs[2 * i + 1];
        }
    }
    
    ll get(int i) {
        assert(0 <= i and i < n);
        int x = i;
        i += n;
        ll res = linf;
        while (i >= 1) {
            chmin(res, eval(a[i], b[i], x));
            i >>= 1;
        }
        return res;
    }
    
    void update(int l, int r, ll na, ll nb) {
        assert(0 <= l and l <= r and r <= n);
        l += n, r += n;
        while (l < r) {
            if (l & 1) _update(l++, na, nb);
            if (r & 1) _update(--r, na, nb);
            l >>= 1, r >>= 1;
        }
    }
};

// Entry point declared in wiring.h (task IOI17_wiring per the submission
// header): returns the minimum total wire length computed by the DP below.
//
// r, b: coordinates of the red and the blue points. They are merged and sorted
//       here, so their order on input does not matter. Both are expected to be
//       non-empty; if exactly one of them is empty no transition ever fires and
//       the returned value is a meaningless sentinel near linf.
//
// Method. Sort all points and cut them into maximal same-colour runs. Let
// dp[k] be the best cost of wiring the first k sorted points. dp[k] is kept
// implicitly as
//     dp[k] = st.get(k) + signed_prefix(colour of point k-1, k),
// where signed_prefix(c, k) sums the first k coordinates with colour c counted
// positive and the other colour negative. With that offset, every transition
// out of a state i is a line in the target index j, so it is pushed into the
// segment tree of lines as a range update.
ll min_total_length(vi r, vi b) {
    const int n = r.size();
    const int m = b.size();
    // No points, no wire. This also keeps v.back() below off an empty vector.
    if (n == 0 and m == 0) return 0;
    const int total = n + m;

    // All points sorted by coordinate and tagged with a colour: 0 = red, 1 = blue.
    vector<pair<ll, int>> v;
    v.reserve(total);
    for (int c : r) v.emplace_back(c, 0);
    for (int c : b) v.emplace_back(c, 1);
    sort(v.begin(), v.end());

    // ps: start index of every same-colour run except the first one, followed
    // by `total` as a closing boundary.
    vi ps;
    for (int i = 0; i + 1 < total; i++) {
        if (v[i].second != v[i + 1].second) ps.push_back(i + 1);
    }
    ps.push_back(total);

    // sum_r counts red coordinates positive and blue negative; sum_b is the
    // mirror image. Differences over a single-colour stretch are +-its sum.
    vl sum_r(total + 1), sum_b(total + 1);
    for (int i = 0; i < total; i++) {
        const ll c = v[i].first;
        const bool red = v[i].second == 0;
        sum_r[i + 1] = sum_r[i] + (red ? c : -c);
        sum_b[i + 1] = sum_b[i] + (red ? -c : c);
    }
    auto signed_prefix = [&](int colour, int idx) -> ll {
        return colour == 0 ? sum_r[idx] : sum_b[idx];
    };

    segtree st(total + 10);
    st.update(0, 1, 0, 0);  // dp[0] = 0
    for (int i = 0; i < total; i++) {
        // dp[i]; every update below targets indices > i, so it is final here.
        ll now = st.get(i);
        if (i > 0) now += signed_prefix(v[i - 1].second, i);

        // Point i lies in a run ending at x; the next run is [x, y).
        auto it = upper_bound(ps.begin(), ps.end(), i);
        const int x = *it;
        ++it;
        if (it == ps.end()) continue;  // i is in the last run: nothing to pair with
        const int y = *it;
        const int cur = v[i].second;
        const int nxt = 1 - cur;

        // Transition 1: wire points i..x-1 with points x..j-1 for j in (x, y].
        // Base cost is (sum of right part) - (sum of left part); the side with
        // more points sends its surplus to the nearest point of the other run.
        // p is the j at which both sides have equally many points.
        ll sub = signed_prefix(nxt, i);
        const int p = 2 * x - i;
        if (y <= p) {
            // Left side never smaller: the p - j surplus left points go to v[x].
            st.update(x + 1, y + 1, -v[x].first, now + v[x].first * p - sub);
        } else {
            st.update(x + 1, p + 1, -v[x].first, now + v[x].first * p - sub);
            // Right side larger: the j - p surplus right points go to v[x-1].
            st.update(p + 1, y + 1, -v[x - 1].first, now + v[x - 1].first * p - sub);
        }

        // Transition 2: only when the next run is the single point v[x]. That
        // point then also serves a prefix y..j-1 of the run after it, [y, z).
        if (y - x > 1) continue;
        ++it;
        if (it == ps.end()) continue;
        const int z = *it;
        // Cost of wiring every point i..x-1 to v[x] ...
        now += signed_prefix(nxt, x) - signed_prefix(nxt, i);
        now += v[x].first * (x - i);
        // ... plus, per target j in (y, z], wiring points y..j-1 to v[x].
        sub = signed_prefix(cur, y);
        st.update(y + 1, z + 1, -v[x].first, now + v[x].first * y - sub);
    }

    // dp[total]
    ll res = st.get(total);
    res += signed_prefix(v.back().second, total);
    return res;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...