답안 #716607

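# 제출 (Submission) | 시각 (Submitted at) | 아이디 (User) | 문제 (Problem) | 언어 (Language) | 결과 (Score) | 실행 시간 (Run time) | 메모리 (Memory)
716607 | 2023-03-30T14:28:32Z | Stickfish | Paths (RMI21_paths) | C++17 | 36 / 100 | 600 ms | 524288 KB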
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
using ll = long long;

// Node of a persistent (path-copying) treap. The functions below keep the
// in-order sequence of `val` non-decreasing and the tree heap-ordered on
// `prior`. Every member has an explicit initializer so that a node is fully
// defined even when it is default-initialised (not only via `new node()`).
struct node {
    node* l = nullptr;  // left child
    node* r = nullptr;  // right child
    int prior = 0;      // heap priority (set from rand() when a node is inserted)
    ll val = 0;         // value stored at this node
    ll sm = 0;          // cached sum of `val` over this subtree
    int sz = 0;         // cached number of nodes in this subtree
};

// Path-copying primitive: allocates and returns a new node whose children,
// priority, value, subtree sum and subtree size all equal those of `nd`.
// The copy shares `nd`'s children. `nd` must be non-null.
node* copy(node* nd) {
    return new node(*nd);
}

// Cached node count of the subtree rooted at `nd`; an empty tree has size 0.
int get_size(node* nd) {
    return nd == nullptr ? 0 : nd->sz;
}

// Cached sum of values in the subtree rooted at `nd`; an empty tree sums to 0.
ll get_sum(node* nd) {
    return nd == nullptr ? 0 : nd->sm;
}

// Recomputes `nd`'s cached subtree size and sum from its own value and the
// aggregates already stored in its children. `nd` must be non-null.
void update(node* nd) {
    const int child_count = get_size(nd->l) + get_size(nd->r);
    const ll child_total = get_sum(nd->l) + get_sum(nd->r);
    nd->sz = child_count + 1;
    nd->sm = child_total + nd->val;
}

// Returns a new version of the treap `nd` in which `val` has been added to
// the value of its rightmost node (the last element in in-order). Every node
// on the right spine is copied and its cached sum refreshed, so the input
// version is left unchanged. `nd` must be non-null.
// NOTE(review): the rightmost node stays the maximum only if val >= 0; the
// caller passes edge costs, presumably non-negative — confirm.
node* add_last(node* nd, ll val) {
    nd = copy(nd);
    if (nd->r == nullptr) {
        nd->val += val;
        nd->sm += val;
        return nd;
    }
    nd->r = add_last(nd->r, val);
    update(nd);
    return nd;
}

// Persistent treap concatenation: returns a tree whose in-order sequence is
// that of `l` followed by that of `r`. Only nodes on the merge path are
// copied, so both inputs remain valid. The root with the strictly greater
// priority stays on top; on a tie `r`'s root wins.
node* merge(node* l, node* r) {
    if (l == nullptr || r == nullptr)
        return l != nullptr ? l : r;
    if (l->prior <= r->prior) {
        // r stays the root; l is merged into r's left subtree.
        node* joined_left = merge(l, r->l);
        node* root = copy(r);
        root->l = joined_left;
        update(root);
        return root;
    }
    // l stays the root; r is merged into l's right subtree.
    node* joined_right = merge(l->r, r);
    node* root = copy(l);
    root->r = joined_right;
    update(root);
    return root;
}

// Persistent split by key: returns {A, B}, where the search sends nodes with
// node->val < `val` to A and nodes with node->val >= `val` to B (exact when
// in-order values are non-decreasing). Only nodes on the search path are
// copied; the input tree is unchanged.
pair<node*, node*> split_value(node* nd, ll val) {
    if (nd == nullptr)
        return {nullptr, nullptr};
    if (nd->val < val) {
        // nd (with its left subtree) belongs to A; split its right subtree.
        const auto [below, above] = split_value(nd->r, val);
        node* kept = copy(nd);
        kept->r = below;
        update(kept);
        return {kept, above};
    }
    // val <= nd->val: nd (with its right subtree) belongs to B.
    const auto [below, above] = split_value(nd->l, val);
    node* kept = copy(nd);
    kept->l = above;
    update(kept);
    return {below, kept};
}

// Persistent split by position: returns {A, B}, where A holds the first `sz`
// nodes of the in-order sequence and B holds the remainder. Only nodes on
// the search path are copied; the input tree is unchanged.
pair<node*, node*> split_size(node* nd, int sz) {
    if (nd == nullptr)
        return {nullptr, nullptr};
    const int left_count = get_size(nd->l);
    if (left_count < sz) {
        // nd and its whole left subtree fall into A; continue on the right.
        const auto [front, back] = split_size(nd->r, sz - left_count - 1);
        node* kept = copy(nd);
        kept->r = front;
        update(kept);
        return {kept, back};
    }
    // The cut lies inside the left subtree; nd goes to B.
    const auto [front, back] = split_size(nd->l, sz);
    node* kept = copy(nd);
    kept->l = back;
    update(kept);
    return {front, kept};
}

// Returns a new treap version with one extra node holding `val`. The node is
// placed after every node with a smaller value and before those with value
// >= `val`. Its priority is drawn from rand(). `nd` may be null (empty
// tree); the input version is not modified.
node* insert(node* nd, ll val) {
    const auto [smaller, rest] = split_value(nd, val);
    node* fresh = new node();
    fresh->val = val;
    fresh->sm = val;
    fresh->sz = 1;
    fresh->prior = rand();
    node* tail = merge(fresh, rest);
    return merge(smaller, tail);
}

// Appends every value of the treap rooted at `nd` to `ans` in reverse
// in-order (right subtree, node, left subtree). For a value-ordered tree
// this is largest first. An explicit stack replaces recursion.
void get_values(node* nd, vector<ll>& ans) {
    vector<node*> pending;
    node* cur = nd;
    while (cur != nullptr || !pending.empty()) {
        // Descend along right children, remembering the path.
        while (cur != nullptr) {
            pending.push_back(cur);
            cur = cur->r;
        }
        cur = pending.back();
        pending.pop_back();
        ans.push_back(cur->val);
        cur = cur->l;
    }
}

// Vertex capacity; presumably the problem bound on n plus slack — confirm.
constexpr int MAXN = 100'123;
// edg[v]: pairs (neighbour, id of the directed edge v -> neighbour).
vector<pair<int, int>> edg[MAXN];
// Per-directed-edge tables. main uses ids 0..2n-3 for the two directions of
// each tree edge and ids 2n..3n-1 for a virtual edge into each root, hence
// the factor 3.
ll cost[MAXN * 3];               // weight of the directed edge
pair<int, int> edges[MAXN * 3];  // edge id -> (from, to)
node* dp[MAXN * 3];              // memoised get_dp result; null until computed

// Merges the treaps in `nds` into one and keeps only its `k` largest values.
// The input with the most nodes is reused as the base, and every value of
// the other inputs is inserted into it (small-to-large). Insertion is
// persistent, so the input trees are not modified.
// Returns nullptr when `nds` is empty or every input is empty.
node* merge_nodes(const vector<node*>& nds, int k) {
    node* base = nullptr;
    for (node* nd : nds) {
        if (get_size(base) < get_size(nd))
            base = nd;
    }
    node* ans = base;
    vector<ll> add;  // reused buffer holding one smaller input's values
    for (node* nd : nds) {
        if (nd == base)
            continue;
        add.clear();
        get_values(nd, add);
        for (ll x : add)
            ans = insert(ans, x);
    }
    // In-order is ascending, so cutting the first `extra` nodes drops the
    // smallest values and keeps exactly the k largest.
    const int extra = get_size(ans) - k;
    if (extra > 0)
        ans = split_size(ans, extra).second;
    return ans;
}

void get_dp(int e, int k) {
    if (get_size(dp[e]))
        return;
    auto [rt, v] = edges[e];
    //if (k == 1) {
        //dp[e] = {0};
    //}
    vector<node*> nds;
    for (auto [u, ne] : edg[v]) {
        if (u == rt)
            continue;
        get_dp(ne, k);
        nds.push_back(dp[ne]);
        //if (k == 1) {
            //dp[e][0] = max(dp[e][0], dp[ne][0]);
            //continue;
        //}
        //vector<ll> ndp(dp[e].size() + dp[ne].size());
        //merge(dp[e].rbegin(), dp[e].rend(), dp[ne].rbegin(), dp[ne].rend(), ndp.rbegin());
        //ndp.resize(min(int(ndp.size()), k));
        //dp[e] = ndp;
    }
    //cout << "# ";
    //for (auto nd : nds)
        //cout << get_sum(nd) << ' ';
    dp[e] = merge_nodes(nds, k);
    //cout << ": " << get_sum(dp[e]);
    if (get_size(dp[e])) {
        //cout << get_sum(dp[e]) << " + " << cost[e] << " -> ";
        dp[e] = add_last(dp[e], cost[e]);
        //cout << get_sum(dp[e]) << endl;
        //cout << "? " << cost[e] << ' ' << get_sum(dp[e]) << '\n';
    } else {
        dp[e] = insert(nullptr, cost[e]);
        //cout << "! " << cost[e] << ' ' << get_sum(dp[e]) << '\n';
        //dp[e] = {cost[e]};
    }
    //cout << " | " << get_sum(dp[e]) << endl;
}

// Reads n, k and n-1 weighted edges with 1-based endpoints. Each undirected
// edge i becomes directed edges 2i and 2i+1. Then for every vertex r a
// virtual edge (n -> r) with cost 0 is created, and the sum of its dp
// treap is printed, one line per vertex.
signed main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, k;
    cin >> n >> k;
    for (int i = 0; i < n - 1; ++i) {
        int a, b, w;
        cin >> a >> b >> w;
        --a;
        --b;
        const int fwd = 2 * i;
        const int bwd = 2 * i + 1;
        edg[a].emplace_back(b, fwd);
        edg[b].emplace_back(a, bwd);
        edges[fwd] = {a, b};
        edges[bwd] = {b, a};
        cost[fwd] = w;
        cost[bwd] = w;
    }
    for (int root = 0; root < n; ++root) {
        // Vertex n never appears in edg[], so it acts as a parent of `root`
        // that no recursion can step back into.
        const int e = 2 * n + root;
        edges[e] = {n, root};
        get_dp(e, k);
        cout << get_sum(dp[e]) << '\n';
    }
}

Compilation message

Main.cpp: In function 'node* add_last(node*, ll)':
Main.cpp:45:8: warning: unused variable 'ndsm' [-Wunused-variable]
   45 |     ll ndsm = get_sum(nd);
      |        ^~~~
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 2772 KB Output is correct
2 Correct 2 ms 2644 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 2772 KB Output is correct
2 Correct 2 ms 2644 KB Output is correct
3 Correct 4 ms 4220 KB Output is correct
4 Correct 6 ms 4696 KB Output is correct
5 Correct 13 ms 12840 KB Output is correct
6 Correct 2 ms 3156 KB Output is correct
7 Correct 3 ms 3668 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 2772 KB Output is correct
2 Correct 2 ms 2644 KB Output is correct
3 Correct 4 ms 4220 KB Output is correct
4 Correct 6 ms 4696 KB Output is correct
5 Correct 13 ms 12840 KB Output is correct
6 Correct 2 ms 3156 KB Output is correct
7 Correct 3 ms 3668 KB Output is correct
8 Correct 29 ms 23396 KB Output is correct
9 Correct 36 ms 38044 KB Output is correct
10 Correct 8 ms 6996 KB Output is correct
11 Correct 437 ms 467336 KB Output is correct
12 Correct 11 ms 10068 KB Output is correct
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 2772 KB Output is correct
2 Correct 2 ms 2644 KB Output is correct
3 Correct 4 ms 4220 KB Output is correct
4 Correct 6 ms 4696 KB Output is correct
5 Correct 13 ms 12840 KB Output is correct
6 Correct 2 ms 3156 KB Output is correct
7 Correct 3 ms 3668 KB Output is correct
8 Correct 29 ms 23396 KB Output is correct
9 Correct 36 ms 38044 KB Output is correct
10 Correct 8 ms 6996 KB Output is correct
11 Correct 437 ms 467336 KB Output is correct
12 Correct 11 ms 10068 KB Output is correct
13 Correct 129 ms 77544 KB Output is correct
14 Correct 235 ms 218616 KB Output is correct
15 Correct 19 ms 12316 KB Output is correct
16 Runtime error 494 ms 524288 KB Execution killed with signal 9
17 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 369 ms 164864 KB Output is correct
2 Correct 356 ms 142560 KB Output is correct
3 Correct 173 ms 73056 KB Output is correct
4 Execution timed out 632 ms 524288 KB Time limit exceeded
5 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 2 ms 2772 KB Output is correct
2 Correct 2 ms 2644 KB Output is correct
3 Correct 4 ms 4220 KB Output is correct
4 Correct 6 ms 4696 KB Output is correct
5 Correct 13 ms 12840 KB Output is correct
6 Correct 2 ms 3156 KB Output is correct
7 Correct 3 ms 3668 KB Output is correct
8 Correct 29 ms 23396 KB Output is correct
9 Correct 36 ms 38044 KB Output is correct
10 Correct 8 ms 6996 KB Output is correct
11 Correct 437 ms 467336 KB Output is correct
12 Correct 11 ms 10068 KB Output is correct
13 Correct 129 ms 77544 KB Output is correct
14 Correct 235 ms 218616 KB Output is correct
15 Correct 19 ms 12316 KB Output is correct
16 Runtime error 494 ms 524288 KB Execution killed with signal 9
17 Halted 0 ms 0 KB -