제출 #563269

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
563269 | ngpin04 | 도로 폐쇄 (APIO21_roads) | C++14
100 / 100
660 ms | 292144 KiB
#include <bits/stdc++.h>
#include "roads.h"
#define fi first
#define se second
#define mp make_pair
#define TASK ""
#define bit(x) (1LL << (x))
#define getbit(x, i) (((x) >> (i)) & 1)
#define ALL(x) (x).begin(), (x).end() 
using namespace std;
// Lowers `a` to `b` when `b` is strictly smaller (ties leave `a` untouched).
// Returns true exactly when `a` was overwritten.
template <typename T1, typename T2>
bool mini(T1 &a, T2 b) {
    if (!(a > b))
        return false;
    a = b;
    return true;
}
// Raises `a` to `b` when `b` is strictly larger (ties leave `a` untouched).
// Returns true exactly when `a` was overwritten.
template <typename T1, typename T2>
bool maxi(T1 &a, T2 b) {
    if (!(a < b))
        return false;
    a = b;
    return true;
}
// Shared 64-bit Mersenne Twister, seeded once from the steady clock.
mt19937_64 rd(chrono::steady_clock::now().time_since_epoch().count());

// Pseudo-random integer in the closed range [l, r]; expects l <= r.
// Plain modulo reduction is used, so the distribution has a slight bias.
int rand(int l, int r) {
    const int width = r - l + 1;
    const auto offset = rd() % width;
    return l + offset;
}
// Array capacity: upper bound on vertices/edges plus slack.
constexpr int N = 100000 + 5;
// Conventional "infinity" sentinels for int / long long.
constexpr int oo = 1000000000;
constexpr long long ooo = 1000000000000000000LL;
// Modulus (alternative kept for reference: 998244353).
constexpr int mod = 1000000007;
// pi computed in double precision, widened to long double.
const long double pi = std::acos(-1.0);

/*
 * Binary trie over 61-bit magnitudes that answers "sum of the k extreme
 * elements" queries on a multiset.
 *
 * Keys are stored as |x|.
 *  - With `neg` set, getkth() walks toward larger magnitudes first and negates
 *    its result. For a multiset of negative numbers it therefore returns the
 *    sum of the k most negative ones.
 *  - With `neg` clear, it walks toward smaller magnitudes first and returns
 *    the sum of the k smallest magnitudes.
 *
 * Removal only decrements counters; nodes are never reclaimed. Memory
 * therefore grows with the number of add() calls (up to 61 nodes each).
 */
struct trie {
    struct node {
        int cnt; long long tot;   // multiplicity / sum of keys in this subtree
        int ptr[2];               // child indices into `tree`; 0 means absent
        node (int _cnt = 0, long long _tot = 0) {
            cnt = _cnt;
            tot = _tot;
            ptr[0] = ptr[1] = 0;
        }
    };

    bool neg;                  // prefer large magnitudes and negate answers
    int num, sz;               // num: last allocated node index; sz: multiset size
    std::vector <node> tree;   // tree[0] is the root

    trie(int _neg = 0) {
        neg = _neg;
        num = 0;
        sz = 0;
        tree.emplace_back();
    }

    // Inserts `val` copies of |x|; a negative `val` removes copies.
    void add(long long x, int val) {
        sz += val;
        x = std::abs(x);
        int cur = 0;
        for (int i = 60; i >= 0; i--) {
            const int b = static_cast<int>((x >> i) & 1);
            if (!tree[cur].ptr[b]) {
                // Record the index before growing `tree`: emplace_back may
                // reallocate, so no reference into `tree` is held across it.
                tree[cur].ptr[b] = ++num;
                tree.emplace_back();
            }
            cur = tree[cur].ptr[b];
            tree[cur].cnt += val;
            // Scale by the multiplicity so that tot == cnt * key stays true
            // for every |val|, not only for val == +1 / -1.
            tree[cur].tot += val * x;
        }
    }

    // Sum of the k extreme stored keys (ordering and sign as described above).
    // k is clamped to sz. Without the clamp, a k larger than the multiset would
    // re-add the root's subtrees at every level and return a wrong sum.
    // k <= 0 yields 0.
    long long getkth(int k) {
        if (k > sz)
            k = sz;
        long long res = 0;
        long long key = 0;   // bits of the path descended so far
        int cur = 0;
        for (int i = 60; k > 0 && i >= 0; i--) {
            for (int v = 0; v < 2; v++) {
                const int dir = v ^ neg;
                const int p = tree[cur].ptr[dir];
                // p == 0 (absent child) refers to the root, whose cnt/tot are
                // never updated by add() and stay 0, so it contributes nothing.
                if (tree[p].cnt > k) {
                    if (dir)
                        key |= 1LL << i;
                    cur = p;
                    break;
                }
                k -= tree[p].cnt;
                res += tree[p].tot;
            }
        }
        // Any copies still needed all share the key of the leaf reached.
        if (k > 0)
            res += k * key;
        return neg ? -res : res;
    }
};

// adj[u]: indices of the edges incident to vertex u.
vector <int> adj[N];
// deg[d]: vertices whose degree (adj[u].size()) is exactly d.
vector <int> deg[N];
// g[u]: edges from u to its DFS children, recorded by add() once both
// endpoints are flagged; solve(u, k) recurses along these edges.
vector <int> g[N];

// neg[u]: per-vertex trie (reset to trie(1) in solve()) holding the
// negative candidate values inserted through addvalue().
trie neg[N];

// dp[u][0] / dp[u][1]: written by solve(u, k). dp[u][0] applies the k-1
// most negative candidates of neg[u], dp[u][1] the k most negative
// (both use all of them when fewer than k exist).
long long dp[N][2];

// val[u]: total weight of edges from u to neighbours that were not flagged
// when u was added; reduced again in add() when such a neighbour is added.
long long val[N];
int par[N];   // DFS parent set by dfs(); -1 for the root passed to dfs()
int fr[N];    // fr[i], to[i]: endpoints of edge i
int to[N];
int h[N];     // NOTE(review): not referenced anywhere in the visible code
int w[N];     // w[i]: weight of edge i
int n;        // number of vertices

// NOTE(review): not referenced in the visible code. Declaration order here is
// left untouched on purpose.
bool vis[N];
bool flag[N]; // flag[u]: u has been activated by add()

// Roots the tree at u (with parent p) and fills par[] for every reachable
// vertex. Uses an explicit stack instead of recursion. Neighbours are
// recovered from edge indices via fr[e] ^ to[e] ^ cur.
void dfs(int u, int p = -1) {
    par[u] = p;
    vector<int> stk(1, u);
    while (!stk.empty()) {
        const int cur = stk.back();
        stk.pop_back();
        for (int e : adj[cur]) {
            const int nxt = fr[e] ^ to[e] ^ cur;
            if (nxt == par[cur])
                continue;
            par[nxt] = cur;
            stk.push_back(nxt);
        }
    }
}

// Inserts (val > 0) or erases (val < 0) candidate x in vertex v's trie.
// Only strictly negative candidates are tracked; others are ignored.
void addvalue(int v, long long x, int val) {
    if (x >= 0)
        return;
    neg[v].add(x, val);
}

// Activates vertex u. For every incident edge e = (u, other):
//  * other already flagged: withdraw the charge and candidate that `other`
//    held for this edge, and link e into g[] under the endpoint that is the
//    DFS parent;
//  * other not flagged: charge w[e] to val[u] and register -w[e] as a
//    candidate in u's trie.
// u itself is flagged only after all incident edges are processed.
void add(int u) {
    for (int e : adj[u]) {
        const int other = fr[e] ^ to[e] ^ u;
        if (flag[other]) {
            val[other] -= w[e];
            addvalue(other, -w[e], -1);
            const int owner = (other == par[u]) ? other : u;
            g[owner].push_back(e);
        } else {
            val[u] += w[e];
            addvalue(u, -w[e], 1);
        }
    }
    flag[u] = true;
}

// Tree DP over the flagged component below u for threshold k.
// Children (edges in g[u]) are solved first; then dp[u][0] and dp[u][1] are set.
//   tot    = val[u] (weights of edges to unflagged neighbours) plus, for each
//            child v, dp[v][1] + w (child handled with its parent edge closed).
//   neg[u] = negative candidates: -w per unflagged-neighbour edge (from
//            add()), plus dp[v][0] - (dp[v][1] + w) per child edge, added
//            here (addvalue drops non-negative ones).
// dp[u][1] applies the k most negative candidates and dp[u][0] only k-1
// (presumably: edge to parent closed vs. kept). When fewer than k exist,
// both apply all of them.
// The child candidates are removed again at the end, so between calls neg[u]
// only holds the unflagged-neighbour entries maintained by add().
void solve(int u, int k) {
    long long tot = val[u];

    for (int i : g[u]) {
        int v = fr[i] ^ to[i] ^ u;
        solve(v, k);
        tot += dp[v][1] + w[i];
        // Gain of keeping edge i instead of closing it.
        addvalue(u, dp[v][0] - (dp[v][1] + w[i]), 1);
    }

    int cntneg = neg[u].sz;
    if (cntneg >= k) {
        dp[u][0] = neg[u].getkth(k - 1);
        dp[u][1] = neg[u].getkth(k);    
    } else 
        dp[u][0] = dp[u][1] = neg[u].getkth(cntneg);   

    dp[u][0] += tot;
    dp[u][1] += tot;

    // Undo the child insertions above (same values, dp[v] is unchanged).
    for (int i : g[u]) {
        int v = fr[i] ^ to[i] ^ u;
        addvalue(u, dp[v][0] - (dp[v][1] + w[i]), -1);
    }
}

/*
 * Driver: builds the tree from the global edge arrays and returns one answer
 * per threshold k in [0, n-1].
 *
 * Vertices are activated (add()) in decreasing order of degree, so while
 * handling k every vertex of degree >= k is flagged. Each connected component
 * of flagged vertices is evaluated from its topmost vertex with solve(v, k).
 * res[0] is the sum of all edge weights.
 */
vector <long long> solve() {
    for (int i = 0; i < n; i++)
        neg[i] = trie(1);

    for (int i = 0; i < n - 1; i++) {
        adj[fr[i]].push_back(i);
        adj[to[i]].push_back(i);
    }

    // Bucket vertices by degree so they can be activated from high to low.
    for (int i = 0; i < n; i++)
        deg[adj[i].size()].push_back(i);

    dfs(1);

    vector <long long> res;
    vector <int> node;   // every vertex activated so far

    for (int k = n - 1; k >= 1; k--) {
        for (int v : deg[k]) {
            add(v);
            node.push_back(v);
        }

        long long tmp = 0;
        for (int v : node) {
            // v tops its component when it has no parent (the DFS root,
            // par == -1) or its parent is not flagged. Testing par first
            // avoids the out-of-bounds read flag[-1].
            const bool isTop = par[v] < 0 || !flag[par[v]];
            if (!isTop)
                continue;
            solve(v, k);
            tmp += min(dp[v][0], dp[v][1]);
        }

        res.push_back(tmp);
    }

    res.push_back(accumulate(w, w + (n - 1), 0LL));
    reverse(res.begin(), res.end());
    return res;
}

// Grader entry point: copies the N-1 edges (U[e], V[e]) with weights W[e]
// into the global edge arrays, then returns solve()'s per-threshold answers.
vector<long long> minimum_closure_costs(int N, vector<int> U,
                                             vector<int> V,
                                             vector<int> W) {
    n = N;
    for (int e = 0; e + 1 < n; ++e) {
        fr[e] = U[e];
        to[e] = V[e];
        w[e] = W[e];
    }
    return solve();
}

//#include "grader.cpp"
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...