제출 #1370068

#제출 시각아이디문제언어결과실행 시간메모리
1370068altayeb_132나일강 (IOI24_nile)C++20
100 / 100
88 ms22040 KiB
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
vector<array<ll, 2>> diff;
map<ll, ll> ans;
vector<array<ll, 3>> edge, edge1;
vector<ll> wt;
ll INF = 1e18;
long long cnt = 0, cnt1 = 0;
#define pb push_back
struct dsu {
    vector<ll> p, sz, m, odd, even, mn;
    vector<ll> sm;
    dsu(ll n, vector<ll> vl) {
        for (ll i = 0; i < n; i++) {
            sm.pb(vl[i]);
            m.pb(INF);
            odd.pb(INF);
            even.pb(vl[i]);
            mn.pb(INF);
        }
        n++;
        p.assign(n, 0);
        sz.assign(n, 1);
        iota(p.begin(), p.end(), 0);
    }
    ll find(ll u) {
        if (u == p[u])
            return u;
        return p[u] = find(p[u]);
    }
    void connect(ll u, ll v) {
        // cout << u << ' ' << v << '\n';
        u = find(u);
        v = find(v);
        if (u == v)
            return;
        swap(u, v);
        cnt -= sm[u] + sm[v];
        if (sz[u] == 1)
            cnt += sm[u];
        if (sz[v] == 1)
            cnt += sm[v];
        if (sz[u] % 2 && sz[u] >= 3)
            cnt += mn[u];
        if (sz[v] % 2 && sz[v] >= 3)
            cnt += mn[v];
        // cout << cnt << '\n';
        // cout << odd[u] << ' ' << even[u] << '\n';
        if (sz[v] % 2) {
            odd[v] = min(odd[v], even[u]);
            even[v] = min(even[v], odd[u]);
        }
        else {
            odd[v] = min(odd[v], odd[u]);
            even[v] = min(even[v], even[u]);
        }
        // cout << odd[v] << ' ' << even[v] << ' ';
        sz[v] += sz[u];
        p[u] = v;
        sm[v] += sm[u];
        cnt += sm[v];
        // cout << cnt << ' ' << sm[v] << ' ' << sm[u] << '\n';
        // cout << u << ' ' << v << ' ';
        m[v] = min(m[v], m[u]);
        // if (v % 2 == 0)
        //     mn[v] = min({(ll)even[v], m[v]});
        // else
            mn[v] = min({(ll)even[v], m[v]});
        // cout << mn[v] << '\n';
        if (sz[v] % 2)
            cnt -= mn[v];
    }
    void connect1(int u, int v) {
        // cout << u << ' ' << v << '\n';
        ll pr = find(v), df = wt[min(u, v) + 1];
        // cout << df << '\n';
        m[pr] = min(m[pr], df);
        if (mn[pr] > m[pr]) {
            if (sz[pr] % 2) {
                // cout << df << '\n';
                cnt += mn[pr];
                cnt -= m[pr];
            }
            mn[pr] = m[pr];
        }
    }
};
vector<long long> calculate_costs(vector<int> w, vector<int> a,vector<int> b, vector<int> e) {
    ll n = w.size(),q = e.size();
    vector<long long> org;
    for (ll i = 0; i < q; i++)
        org.push_back(e[i]);
    sort(e.begin(), e.end());
    ll fl = 0;
    vector<ll> tmp;
    for (ll i = 0; i < n; i++) {
        fl += a[i];
        diff.push_back({w[i], a[i] - b[i]});
    }
    sort(diff.begin(), diff.end());
    for (ll i = 0; i < n - 1; i++) {
        edge.push_back({diff[i + 1][0] - diff[i][0], i, i + 1});
        if (i != n - 2)
            edge1.push_back({diff[i + 2][0] - diff[i][0], i, i + 2});
    }
    for (ll i = 0; i < n; i++)
        tmp.push_back(diff[i][1]);
    wt = tmp;
    dsu ds(n, tmp);
    sort(edge.begin(), edge.end());
    sort(edge1.begin(), edge1.end());
    ll nw = 0, nw1 = 0;
    for (ll i = 0; i < q; i++) {
        while (nw < n - 1 && e[i] >= edge[nw][0]) {
            ds.connect(edge[nw][1], edge[nw][2]);
            nw++;
            // cout << cnt << '\n';
        }
        //
        while (nw1 < n - 2 && e[i] >= edge1[nw1][0]) {
            ds.connect1(edge1[nw1][1], edge1[nw1][2]);
            nw1++;
        }
        ans[e[i]] = fl - cnt;
    }
    for (auto &i : org)
        i = ans[i];
    return org;
}

#include "nile.h"
// int main() {int N;assert(1 == scanf("%d", &N));std::vector<int> W(N), A(N), B(N);for (int i = 0; i < N; i++)assert(3 == scanf("%d%d%d", &W[i], &A[i], &B[i]));int Q;assert(1 == scanf("%d", &Q));std::vector<int> E(Q);for (int j = 0; j < Q; j++)assert(1 == scanf("%d", &E[j]));fclose(stdin);std::vector<long long> R = calculate_costs(W, A, B, E);int S = (int)R.size();for (int j = 0; j < S; j++)printf("%lld\n", R[j]);fclose(stdout);}
#결과 실행 시간메모리채점기 출력
결과를 불러오는 중입니다…
#결과 실행 시간메모리채점기 출력
결과를 불러오는 중입니다…
#결과 실행 시간메모리채점기 출력
결과를 불러오는 중입니다…
#결과 실행 시간메모리채점기 출력
결과를 불러오는 중입니다…
#결과 실행 시간메모리채점기 출력
결과를 불러오는 중입니다…
#결과 실행 시간메모리채점기 출력
결과를 불러오는 중입니다…
#결과 실행 시간메모리채점기 출력
결과를 불러오는 중입니다…
#결과 실행 시간메모리채점기 출력
결과를 불러오는 중입니다…