제출 #716607

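# | 제출 시각 (submission time) | 아이디 (user) | 문제 (problem) | 언어 (language) | 결과 (result) | 실행 시간 (execution time) | 메모리 (memory)
716607 |  | Stickfish | Paths (RMI21_paths) | C++17 | 36 / 100 | 632 ms | 524288 KiB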
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
using ll = long long;

// Node of a path-copying (persistent) treap. In-order traversal yields the
// stored values in nondecreasing order; `prior` keeps the tree heap-ordered.
// Every member has an initializer so a default-constructed node is fully
// defined (previously `prior` was left indeterminate for `node x;`).
struct node {
    node* l = nullptr;  // left child
    node* r = nullptr;  // right child
    int prior = 0;      // heap priority (assigned from rand() by insert)
    ll val = 0;         // value stored in this node
    ll sm = 0;          // sum of `val` over this subtree
    int sz = 0;         // number of nodes in this subtree
};

// Allocates and returns a fresh node carrying exactly the same children,
// value, sum, priority and size as `nd` (which must be non-null). Used to
// path-copy before any modification so older treap versions stay intact.
node* copy(node* nd) {
    return new node(*nd);
}

// Number of nodes in the subtree rooted at `nd`; an empty tree has size 0.
int get_size(node* nd) {
    return nd != nullptr ? nd->sz : 0;
}

// Sum of all values in the subtree rooted at `nd`; an empty tree sums to 0.
ll get_sum(node* nd) {
    return nd != nullptr ? nd->sm : 0;
}

// Recomputes the cached subtree size and subtree sum of `nd` from its
// children; call after changing `nd->l` or `nd->r`.
void update(node* nd) {
    node* const lc = nd->l;
    node* const rc = nd->r;
    nd->sz = get_size(lc) + get_size(rc) + 1;
    nd->sm = get_sum(lc) + get_sum(rc) + nd->val;
}

// Returns a new version of the treap in which the rightmost node (the largest
// value) has been increased by `val`. Nodes along the right spine are copied,
// so the input treap is not modified.
// Precondition: `nd` is non-null (copy() dereferences it).
node* add_last(node* nd, ll val) {
    nd = copy(nd);
    if (!nd->r) {
        // This node holds the maximum; for val >= 0 increasing it keeps the
        // in-order sequence sorted.
        nd->val += val;
        nd->sm += val;
        return nd;
    }
    nd->r = add_last(nd->r, val);
    update(nd);
    return nd;
}

// Concatenates two treaps, where every value of `l` precedes every value of
// `r` in order, into one treap. The higher-priority root (ties go to `r`)
// becomes the new root. Touched nodes are copied, so neither input changes.
node* merge(node* l, node* r) {
    if (!l || !r)
        return l ? l : r;
    if (l->prior <= r->prior) {
        node* merged_left = merge(l, r->l);
        node* root = copy(r);
        root->l = merged_left;
        update(root);
        return root;
    }
    node* merged_right = merge(l->r, r);
    node* root = copy(l);
    root->r = merged_right;
    update(root);
    return root;
}

// Splits the treap by value.
// Returns {nodes with value < val, nodes with value >= val}.
// Nodes on the search path are copied, so the input treap is unchanged.
pair<node*, node*> split_value(node* nd, ll val) {
    if (nd == nullptr)
        return {nullptr, nullptr};
    node* cur = copy(nd);
    if (nd->val < val) {
        // `nd` and its whole left subtree belong to the "< val" side.
        auto parts = split_value(nd->r, val);
        cur->r = parts.first;
        update(cur);
        return {cur, parts.second};
    }
    // `nd` and its whole right subtree belong to the ">= val" side.
    auto parts = split_value(nd->l, val);
    cur->l = parts.second;
    update(cur);
    return {parts.first, cur};
}

// Splits the treap by position.
// Returns {first `sz` nodes in order, remaining nodes}.
// Nodes on the search path are copied, so the input treap is unchanged.
pair<node*, node*> split_size(node* nd, int sz) {
    if (nd == nullptr)
        return {nullptr, nullptr};
    const int left_sz = get_size(nd->l);
    node* cur = copy(nd);
    if (sz > left_sz) {
        // The left subtree and `nd` itself go to the first part.
        auto [lo, hi] = split_size(nd->r, sz - left_sz - 1);
        cur->r = lo;
        update(cur);
        return {cur, hi};
    }
    auto [lo, hi] = split_size(nd->l, sz);
    cur->l = hi;
    update(cur);
    return {lo, cur};
}

// Returns a new treap version with `val` added. The new node is placed
// before any existing equal values and gets a random priority from rand().
// The input treap is left unchanged.
node* insert(node* nd, ll val) {
    auto halves = split_value(nd, val);
    node* fresh = new node();
    fresh->val = val;
    fresh->sm = val;
    fresh->sz = 1;
    fresh->prior = rand();
    node* upper = merge(fresh, halves.second);
    return merge(halves.first, upper);
}

// Appends every value of the treap to `ans` in descending order
// (reverse in-order: right subtree, node, left subtree). Implemented with an
// explicit stack instead of recursion.
void get_values(node* nd, vector<ll>& ans) {
    vector<node*> pending;
    node* cur = nd;
    while (cur != nullptr || !pending.empty()) {
        while (cur != nullptr) {
            pending.push_back(cur);
            cur = cur->r;
        }
        cur = pending.back();
        pending.pop_back();
        ans.push_back(cur->val);
        cur = cur->l;
    }
}

// Capacity for vertices; presumably the problem bound n <= 1e5 plus slack — TODO confirm.
const int MAXN = 1e5 + 123;
// edg[v]: adjacency list of {neighbour, id of the directed edge v -> neighbour}.
vector<pair<int, int>> edg[MAXN];
// cost[e]: weight of directed edge e. The virtual edges 2n..3n-1 are never
// assigned in main, so they keep the zero value of this global array.
ll cost[MAXN * 3];
// edges[e]: {tail, head} of directed edge e. Ids 2i and 2i+1 are the two
// directions of input edge i; ids 2n..3n-1 are virtual edges n -> r used as roots.
pair<int, int> edges[MAXN * 3];
// dp[e]: memoized treap computed by get_dp(e, k); nullptr until computed.
node* dp[MAXN * 3];
// Earlier vector-based representation of dp, kept for reference:
//vector<ll> dp[MAXN * 3];

node* merge_nodes(vector<node*> nds, int k) {
    node* basend = nullptr;
    for (auto nd : nds) {
        if (get_size(basend) < get_size(nd))
            basend = nd;
    }
    node* ans = basend;
    for (auto nd : nds) {
        if (nd == basend)
            continue;
        vector<ll> add;
        get_values(nd, add);
        for (auto x : add) {
            ans = insert(ans, x);
        }
    }
    if (get_size(ans) > k) {
        vector<ll> vals;
        get_values(ans, vals);
        ans = split_size(ans, get_size(ans) - k).second;
        vals.clear();
        get_values(ans, vals);
    }
    return ans;
}

void get_dp(int e, int k) {
    if (get_size(dp[e]))
        return;
    auto [rt, v] = edges[e];
    //if (k == 1) {
        //dp[e] = {0};
    //}
    vector<node*> nds;
    for (auto [u, ne] : edg[v]) {
        if (u == rt)
            continue;
        get_dp(ne, k);
        nds.push_back(dp[ne]);
        //if (k == 1) {
            //dp[e][0] = max(dp[e][0], dp[ne][0]);
            //continue;
        //}
        //vector<ll> ndp(dp[e].size() + dp[ne].size());
        //merge(dp[e].rbegin(), dp[e].rend(), dp[ne].rbegin(), dp[ne].rend(), ndp.rbegin());
        //ndp.resize(min(int(ndp.size()), k));
        //dp[e] = ndp;
    }
    //cout << "# ";
    //for (auto nd : nds)
        //cout << get_sum(nd) << ' ';
    dp[e] = merge_nodes(nds, k);
    //cout << ": " << get_sum(dp[e]);
    if (get_size(dp[e])) {
        //cout << get_sum(dp[e]) << " + " << cost[e] << " -> ";
        dp[e] = add_last(dp[e], cost[e]);
        //cout << get_sum(dp[e]) << endl;
        //cout << "? " << cost[e] << ' ' << get_sum(dp[e]) << '\n';
    } else {
        dp[e] = insert(nullptr, cost[e]);
        //cout << "! " << cost[e] << ' ' << get_sum(dp[e]) << '\n';
        //dp[e] = {cost[e]};
    }
    //cout << " | " << get_sum(dp[e]) << endl;
}

// Reads n, k and the n-1 weighted tree edges. For every vertex r it prints
// the sum of dp over the virtual zero-cost edge n -> r.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    int n, k;
    cin >> n >> k;

    // Input edge i becomes two directed edges: 2i (a -> b) and 2i+1 (b -> a).
    for (int i = 0; i < n - 1; ++i) {
        int a, b, w;
        cin >> a >> b >> w;
        --a;
        --b;
        const int fwd = 2 * i;
        const int bwd = 2 * i + 1;
        edg[a].push_back({b, fwd});
        edg[b].push_back({a, bwd});
        edges[fwd] = {a, b};
        edges[bwd] = {b, a};
        cost[fwd] = w;
        cost[bwd] = w;
    }

    // Virtual edge 2n + r enters vertex r from the non-existent vertex n.
    // Its cost stays 0 because cost[] is a zero-initialized global.
    for (int r = 0; r < n; ++r) {
        const int e = 2 * n + r;
        edges[e] = {n, r};
        get_dp(e, k);
        cout << get_sum(dp[e]) << '\n';
    }
    return 0;
}

컴파일 시 표준 에러 (stderr) 메시지

Main.cpp: In function 'node* add_last(node*, ll)':
Main.cpp:45:8: warning: unused variable 'ndsm' [-Wunused-variable]
   45 |     ll ndsm = get_sum(nd);
      |        ^~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...