제출 #521769

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
521769 | — | maomao90 | 도로 폐쇄 (APIO21_roads) | C++17 | 100 / 100 | 532 ms | 168448 KiB

// Fight the good fight of the faith. Take hold of the 
// eternal life to which you were called when you made 
// your good confession in the presence of many witnesses.
// 1 Timothy 6:12
#include <bits/stdc++.h> 
using namespace std;

// Lowers a to b when b is smaller; returns whether a was changed.
template <class T>
inline bool mnto(T& a, T b) {
    if (a > b) {
        a = b;
        return true;
    }
    return false;
}
// Raises a to b when b is larger; returns whether a was changed.
template <class T>
inline bool mxto(T& a, T b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}
// Range loops: REP runs idx over [from, to); RREP runs idx from `from` down to `to` inclusive.
#define REP(idx, from, to) for (int idx = (from); idx < (to); idx++)
#define RREP(idx, from, to) for (int idx = (from); idx >= (to); idx--)

// Short names for common standard types.
using ll = long long;
using ld = long double;
using ii = pair<int, int>;
using pll = pair<ll, ll>;
using iii = tuple<int, int, int>;
using vi = vector<int>;
using vll = vector<ll>;
using vii = vector<ii>;

// Short names for common standard factories and members.
#define MP make_pair
#define MT make_tuple
#define FI first
#define SE second
#define pb push_back
#define ALL(cont) (cont).begin(), (cont).end()

// debug(...) forwards to printf only when compiled with -DDEBUG; otherwise it expands to nothing.
#ifdef DEBUG
#define debug(...) printf(__VA_ARGS__)
#else
#define debug(...)
#endif

// Large sentinels and the capacity of the per-vertex arrays.
#define INF 1000000005ll
#define LINF 1000000000000000005ll
#define MAXN 100005

// Vertex count of the current input (assigned by minimum_closure_costs).
int n;
// tadj[i]: every (neighbour, weight) edge of vertex i in the input tree.
vii tadj[MAXN];
// adj[i]: edges from i to other vertices that are currently switched on.
vii adj[MAXN];
// deg[d]: vertices whose degree is exactly d.
vi deg[MAXN];
// on[i]: vertex i has been switched on (its degree exceeds the current k).
bool on[MAXN] = {};
// All vertices switched on so far.
set<int> useful;
// Visited flags for the forest traversal at the current k.
bool vis[MAXN] = {};

// Sparse (dynamically allocated) segment tree over the value range [0, INF]
// holding a multiset of values. Supports adding/removing copies of a value
// and querying the sum of the k smallest stored values.
// Nodes live in parallel vectors; node 0 is the root and a child index of -1
// means that child has not been created yet.
struct ST {
    vi cnt;     // cnt[u]: number of values stored in u's range
    vll sm;     // sm[u]: sum of the values stored in u's range
    vi lc, rc;  // left/right child index of u, or -1 if absent

    ST() {
        newNode();  // root
    }

    // Appends an empty node and returns its index.
    int newNode() {
        cnt.pb(0);
        sm.pb(0);
        lc.pb(-1);
        rc.pb(-1);
        return static_cast<int>(cnt.size()) - 1;
    }

    // Adds x copies of value p (x < 0 removes copies). p must lie in [lo, hi].
    void add(ll p, int x, int u = 0, ll lo = 0, ll hi = INF) {
        if (lo == hi) {
            cnt[u] += x;
            sm[u] += x * p;
            return;
        }
        ll mid = (lo + hi) >> 1;
        if (p <= mid) {
            if (lc[u] == -1) {
                // Create first, then store: newNode() may reallocate lc.
                int child = newNode();
                lc[u] = child;
            }
            add(p, x, lc[u], lo, mid);
        } else {
            if (rc[u] == -1) {
                int child = newNode();
                rc[u] = child;
            }
            add(p, x, rc[u], mid + 1, hi);
        }
        cnt[u] = (lc[u] == -1 ? 0 : cnt[lc[u]]) + (rc[u] == -1 ? 0 : cnt[rc[u]]);
        sm[u] = (lc[u] == -1 ? 0 : sm[lc[u]]) + (rc[u] == -1 ? 0 : sm[rc[u]]);
    }

    // Returns the sum of the k smallest values stored in u's subtree.
    // Requires 0 <= k <= cnt[u]. The check runs before the leaf shortcut so an
    // invalid k fails loudly instead of returning lo * k at a missing leaf.
    ll qry(int k, int u = 0, ll lo = 0, ll hi = INF) {
        assert(u != -1 && 0 <= k && k <= cnt[u]);
        if (lo == hi) {
            return lo * k;  // every value stored in a leaf equals lo
        }
        if (cnt[u] == k) {
            return sm[u];
        }
        ll mid = (lo + hi) >> 1;
        if (lc[u] == -1) {
            return qry(k, rc[u], mid + 1, hi);
        }
        if (cnt[lc[u]] >= k) {
            return qry(k, lc[u], lo, mid);
        }
        return sm[lc[u]] + qry(k - cnt[lc[u]], rc[u], mid + 1, hi);
    }
} st[MAXN];

// Current degree bound: every vertex may keep at most k of its edges.
// Set by minimum_closure_costs before each round of dp() calls.
int k;

// Tree DP over the forest of switched-on vertices (adj), rooted at u with
// parent p (p == -1 for a root). Edges from u to vertices that are not on are
// already stored as weights in st[u]; closing such an edge costs its weight.
//
// Returns {cost if the edge to the parent stays open, cost if it is closed}.
// In the first case u may keep at most k - 1 other edges, in the second at
// most k. Neither value includes the parent edge's own weight; the caller
// adds it when it chooses to close that edge.
pll dp(int u, int p) {
    vll penalty;  // extra cost of later closing each child edge kept open
    ll base = 0;  // cost with every child edge in its locally cheaper state
    // Edges currently open at u; off-neighbour edges start out open.
    int openEdges = st[u].cnt[0];
    assert(static_cast<size_t>(openEdges) + adj[u].size() == tadj[u].size());
    for (auto [v, w] : adj[u]) {
        if (v == p) continue;
        vis[v] = 1;
        auto [keepCost, closeCost] = dp(v, u);
        closeCost += w;  // closing edge u-v also pays its weight
        if (closeCost - keepCost < 0) {
            // Closing is strictly cheaper, so close it; it no longer counts
            // toward u's open edges.
            base += closeCost;
        } else {
            debug("  +%lld\n", closeCost - keepCost);
            penalty.pb(closeCost - keepCost);
            base += keepCost;
            openEdges++;
        }
    }
    debug(" %d %d %lld\n", u, openEdges, base);
    // Temporarily add the child penalties next to the off-neighbour edge
    // weights so a single query picks the cheapest edges to close.
    for (ll c : penalty) {
        st[u].add(c, 1);
    }
    // Cost when at most `limit` edges may stay open at u: close the
    // (openEdges - limit) cheapest candidates.
    auto solve = [&](int limit) {
        if (openEdges <= limit) {
            return base;
        }
        int excess = openEdges - limit;
        debug(" %d %d\n", excess, st[u].cnt[0]);
        return base + st[u].qry(excess);
    };
    ll keepParent = solve(k - 1), closeParent = solve(k);
    // Restore st[u] to just the off-neighbour edges.
    for (ll c : penalty) {
        st[u].add(c, -1);
    }
    return MP(keepParent, closeParent);
}

// Entry point (APIO 2021 "roads"). u[i], v[i], w[i] describe the n - 1 weighted
// edges of a tree on vertices 0..n-1. Returns ans where ans[k] is the minimum
// total weight of edges to remove so that every vertex keeps at most k edges.
//
// k is processed from n - 1 down to 1. A vertex is switched on once its degree
// exceeds k; only on vertices can violate the bound. Edges from an on vertex to
// an off vertex live in that vertex's ST, edges between on vertices live in
// adj, and dp() solves each tree of the resulting forest. A vertex of degree d
// is on for d - 1 values of k, so the number of dp() calls summed over all k
// is bounded by the sum of degrees.
//
// All global state is reset first, so the function can be called repeatedly
// (e.g. by a local harness running several test cases).
vll minimum_closure_costs(int n, vi u, vi v, vi w) {
    ::n = n;
    assert(0 <= n && n < MAXN);
    // Clear leftovers from a previous call. Indices 0..n are touched: deg[n]
    // is read when k = n - 1.
    useful.clear();
    REP (i, 0, n + 1) {
        tadj[i].clear();
        adj[i].clear();
        deg[i].clear();
        on[i] = 0;
        vis[i] = 0;
        st[i] = ST();
    }
    ll total = 0;
    REP (i, 0, n - 1) {
        tadj[u[i]].pb(MP(v[i], w[i]));
        tadj[v[i]].pb(MP(u[i], w[i]));
        total += w[i];
    }
    REP (i, 0, n) {
        deg[tadj[i].size()].pb(i);
    }
    vll ans(n, 0);
    RREP (k, n - 1, 1) {
        ::k = k;
        debug("%d\n", k);
        // Vertices of degree exactly k + 1 now exceed the bound: switch them on.
        for (int i : deg[k + 1]) {
            useful.insert(i);
            on[i] = 1;
            for (auto [nb, wt] : tadj[i]) {
                if (on[nb]) {
                    // Both ends are on: move the edge from st[nb] into the forest.
                    adj[nb].pb(MP(i, wt));
                    adj[i].pb(MP(nb, wt));
                    st[nb].add(wt, -1);
                } else {
                    // nb is not on (yet): record the edge in st[i]. If nb is
                    // switched on later in this loop, it moves the edge to adj.
                    st[i].add(wt, 1);
                }
            }
        }
        // Solve every tree of the forest of on vertices exactly once.
        for (int i : useful) {
            if (vis[i]) continue;
            vis[i] = 1;
            ans[k] += dp(i, -1).SE;  // a root has no parent edge
        }
        for (int i : useful) {
            vis[i] = 0;
        }
    }
    // k = 0: every edge must be removed.
    if (n > 0) {
        ans[0] = total;
    }
    return ans;
}

컴파일 시 표준 에러 (stderr) 메시지

roads.cpp: In member function 'void ST::add(ll, int, int, ll, ll)':
roads.cpp:63:21: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   63 |         ll mid = lo + hi >> 1;
      |                  ~~~^~~~
roads.cpp: In member function 'll ST::qry(int, int, ll, ll)':
roads.cpp:94:21: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
   94 |         ll mid = lo + hi >> 1;
      |                  ~~~^~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...