제출 #1337963

| # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
|---|---|---|---|---|---|---|---|
| 1337963 |  | lvjosav | Flooding Wall (BOI24_wall) | C++20 | 12 / 100 | 3815 ms | 73932 KiB |
#pragma GCC optimize("O3")
#include <bits/stdc++.h>
using namespace std;

#define rep(i, a, b) for(int i = a; i < (b); ++i)
#define all(x) begin(x), end(x)
#define sz(x) (int)(x).size()
using ll = long long;
using pii = std::pair<int, int>;
using pll = std::pair<ll, ll>;
using vi = std::vector<int>;
using vl = std::vector<ll>;
using vvl = std::vector<std::vector<ll>>;
// LLONG_MAX is the standard <climits> name; LONG_LONG_MAX is a GNU extension
// that is not guaranteed to exist outside GNU compilation modes.
const ll inf = LLONG_MAX / 2;

const ll mod = 1000000007;

// Segment tree over positions [0, nt), all arithmetic modulo `mod`:
//   point_add(x, v)    value[x] += v
//   range_mul(l, r, v) value[i] *= v for every i in [l, r]
//   range_sum(r)       value[0] + ... + value[r]   (0 when r < 0)
//   get(i)             value[i]
// Lazy invariant: once all ancestors have been prop'd, the true subtree sum of
// node k is sum[k] * lz_mul[k]. A node must therefore be prop'd before its
// sum[] entry is read.
struct segtree {
    ll nt = 1;                    // leaf count: smallest power of two >= n
    std::vector<ll> sum, lz_mul;  // per-node sums / pending multipliers; root is 1, children 2k, 2k+1

    segtree(ll n) {
        nt = 1;
        while (nt < n) nt *= 2;
        sum = std::vector<ll>(2*nt);
        lz_mul = std::vector<ll>(2*nt, 1);
    }

    // Apply node k's pending multiplier to sum[k] and defer it to k's children.
    void prop(ll k) {
        (sum[k] *= lz_mul[k]) %= mod;
        if (k < nt) {
            (lz_mul[2*k] *= lz_mul[k]) %= mod;
            (lz_mul[2*k+1] *= lz_mul[k]) %= mod;
        }
        lz_mul[k] = 1;
    }

    void point_add(ll x, ll v) { return point_add(1, 0, nt-1, x, v); }
    void point_add(ll k, ll tl, ll tr, ll x, ll v) {
        prop(k);
        if (x < tl || x > tr) return;
        if (tl == tr) {
            (sum[k] += v) %= mod;
            return;
        }
        ll mid = tl + (tr - tl) / 2;
        // Both children are visited. The untouched one returns right after its
        // prop, so both sums are current before the parent is recomputed.
        point_add(2*k, tl, mid, x, v);
        point_add(2*k+1, mid+1, tr, x, v);
        sum[k] = (sum[2*k] + sum[2*k+1]) % mod;
    }

    void range_mul(ll l, ll r, ll v) { return range_mul(1, 0, nt-1, l, r, v); }
    void range_mul(ll k, ll tl, ll tr, ll l, ll r, ll v) {
        prop(k);
        if (r < tl || l > tr) return;
        if (l <= tl && r >= tr) {
            (lz_mul[k] *= v) %= mod;
            prop(k);  // make sum[k] current for the parent's recomputation
            return;
        }
        ll mid = tl + (tr - tl) / 2;
        range_mul(2*k, tl, mid, l, r, v);
        range_mul(2*k+1, mid+1, tr, l, r, v);
        sum[k] = (sum[2*k] + sum[2*k+1]) % mod;
    }

    ll range_sum(ll r) { return range_sum(1, 0, nt-1, r) % mod; }
    ll range_sum(ll k, ll tl, ll tr, ll r) {
        prop(k);
        if (r < tl) return 0;
        if (r >= tr) return sum[k];
        ll mid = tl + (tr - tl) / 2;
        if (r <= mid) return range_sum(2*k, tl, mid, r);
        // The left child is fully covered, but prop(k) above may have just pushed
        // a multiplier into it. Recursing props it before its sum is read;
        // reading sum[2*k] directly would return a stale, un-multiplied value.
        return (range_sum(2*k, tl, mid, r) + range_sum(2*k+1, mid+1, tr, r)) % mod;
    }

    ll get(ll i) {
        return (range_sum(i) - range_sum(i-1) + mod) % mod;
    }
};

// Sums, over all 2^n height configurations (position i takes c[a[i]] or
// c[b[i]]), the quantity  sum_k ( max(h_k, ..., h_{n-1}) - h_k ).
// Positions are scanned left to right; the result is reduced modulo `mod`.
// solve() orders each pair so that a[i] <= b[i]; the range updates below rely
// on that ordering.
//
// Two trees are maintained over height ranks:
//   cnt[v]  = number of (prefix configuration, position k) pairs whose running
//             maximum over h_k..h_i currently equals c[v]
//   wsum[v] = c[v] * cnt[v]
// Appending height c[j] lifts every running maximum below c[j] up to c[j].
// That adds  sum_{v<j} cnt[v] * (c[j] - c[v])  to the total.
ll solve_with_left_wall(ll n, vector<ll> &a, vector<ll> &b, vector<ll> &c) {
    segtree cnt(2*n), wsum(2*n);
    ll total = 0;    // answer summed over all configurations of the prefix so far
    ll configs = 1;  // 2^i mod `mod`: number of configurations of the first i positions

    // Total after appending height rank j to every prefix configuration.
    auto extend = [&](ll j) {
        ll below = cnt.range_sum(j-1);
        ll below_w = wsum.range_sum(j-1);
        return (total + c[j] * below - below_w + mod) % mod;
    };
    // Add `amount` pairs whose running maximum is c[j].
    auto deposit = [&](ll j, ll amount) {
        cnt.point_add(j, amount);
        wsum.point_add(j, (c[j] * amount) % mod);
    };
    auto scale = [&](ll l, ll r, ll factor) {
        cnt.range_mul(l, r, factor);
        wsum.range_mul(l, r, factor);
    };

    for (ll i = 0; i < n; i++) {
        const ll lo = a[i];
        const ll hi = b[i];
        const ll with_lo = extend(lo);
        const ll with_hi = extend(hi);
        total = (with_lo + with_hi) % mod;

        const ll under_lo = cnt.range_sum(lo-1);
        const ll under_hi = cnt.range_sum(hi-1);
        scale(hi, 2*n-1, 2);  // maxima >= c[hi] are kept under both choices
        scale(0, lo-1, 0);    // maxima < c[lo] are raised under both choices
        // Raised pairs plus the new position itself (one per prefix configuration).
        deposit(hi, under_hi + configs);
        deposit(lo, under_lo + configs);
        configs = configs * 2 % mod;
    }
    return total;
}

// Sums, over all 2^n height configurations (position i takes c[a[i]] or
// c[b[i]]), the quantity  n * max_k h_k - sum_k h_k, modulo `mod`.
// Both parts are measured against the tallest value top = c.back():
//   - n * (max - top) is taken from the distribution of the global maximum;
//   - sum_k (top - h_k) counts each individual choice once per configuration
//     containing it, i.e. 2^(n-1) times.
// solve() orders each pair so that a[i] <= b[i].
ll solve_filled(ll n, vector<ll> &a, vector<ll> &b, vector<ll> &c) {
    // maxCount[v] = number of prefix configurations whose maximum is c[v].
    segtree maxCount(2*n);
    maxCount.point_add(a[0], 1);
    maxCount.point_add(b[0], 1);
    for (ll i = 1; i < n; i++) {
        const ll lo = a[i];
        const ll hi = b[i];
        const ll belowLo = maxCount.range_sum(lo-1);
        const ll belowHi = maxCount.range_sum(hi-1);
        maxCount.range_mul(hi, 2*n-1, 2);  // maxima >= c[hi] survive both choices
        maxCount.range_mul(0, lo-1, 0);    // maxima < c[lo] are replaced by either choice
        maxCount.point_add(lo, belowLo);
        maxCount.point_add(hi, belowHi);
    }

    const ll top = c.back();
    ll acc = 0;
    for (ll v = 0; v < sz(c); v++) {
        const ll ways = maxCount.get(v);
        const ll gap = n * (top - c[v]) % mod;
        acc = (acc - gap * ways) % mod;
    }
    acc = (acc % mod + mod) % mod;

    ll half = 1;  // 2^(n-1) mod `mod`
    for (ll step = 1; step < n; step++) half = half * 2 % mod;

    for (ll i = 0; i < n; i++) {
        const ll slack = (top - c[a[i]]) + (top - c[b[i]]);
        acc = (acc + half * slack) % mod;
    }
    return acc;
}

void solve() {
    ll n; cin >> n;
    vector<ll> a(n), b(n);
    for (auto &e : a) cin >> e;
    for (auto &e : b) cin >> e;
    set<ll> vs;
    for (ll i = 0; i < n; i++) {
        if (a[i] > b[i]) swap(a[i], b[i]);
        vs.insert(a[i]);
        vs.insert(b[i]);
    }
    vector<ll> c(all(vs));
    for (auto &e : a) e = lower_bound(all(c), e) - c.begin();
    for (auto &e : b) e = lower_bound(all(c), e) - c.begin();

    ll res = mod - solve_filled(n, a, b, c);
    res += solve_with_left_wall(n, a, b, c);
    reverse(all(a)); reverse(all(b));
    res += solve_with_left_wall(n, a, b, c);
    cout << res % mod << '\n';
}

signed main() {
    cin.tie(0)->sync_with_stdio(0);
    solve();
}
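| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
Fetching results...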