This submission was migrated from the previous version of oj.uz, which used a different machine for grading. It may get a different result if resubmitted.
#include <algorithm>
#include <iostream>
#include <random>
#include <vector>
using namespace std;
using ll = long long;

// Node of a persistent (path-copying) treap ordered by `val`.
// Every field has an in-class initializer so that even a default-initialized
// node (e.g. `node n;`) never carries an indeterminate priority.
struct node {
    node* l = nullptr;  // left child (values smaller than `val`)
    node* r = nullptr;  // right child (values not smaller than `val`)
    int prior = 0;      // heap priority compared by merge()
    ll val = 0;         // value stored in this node
    ll sm = 0;          // sum of `val` over this subtree, maintained by update()
    int sz = 0;         // node count of this subtree, maintained by update()
};
// Allocates a fresh node holding the same children, priority, value, subtree
// sum and subtree size as `nd`; this is the copy step of path copying.
// `nd` must be non-null.
node* copy(node* nd) {
    return new node(*nd);
}
// Number of nodes in the subtree rooted at `nd`; an empty tree has size 0.
int get_size(node* nd) {
    return nd == nullptr ? 0 : nd->sz;
}
// Sum of all values in the subtree rooted at `nd`; an empty tree sums to 0.
ll get_sum(node* nd) {
    return nd != nullptr ? nd->sm : 0LL;
}
// Recomputes the cached subtree size and subtree sum of `nd` from its own
// value and its (possibly null) children. `nd` must be non-null.
void update(node* nd) {
    node* const lhs = nd->l;
    node* const rhs = nd->r;
    nd->sz = 1 + get_size(lhs) + get_size(rhs);
    nd->sm = nd->val + get_sum(lhs) + get_sum(rhs);
}
// Returns a new version of the treap `nd` in which `val` has been added to the
// value of its right-most (in-order last, i.e. largest) node. Every node on the
// path to that node is copied, so the input version is left untouched.
// Ordering is preserved as long as `val` is non-negative, because the largest
// value only grows.
// Precondition: `nd` is non-null (copy() dereferences it).
node* add_last(node* nd, ll val) {
    nd = copy(nd);
    if (!nd->r) {
        // This is the right-most node: bump its value and its subtree sum.
        nd->val += val;
        nd->sm += val;
        return nd;
    }
    nd->r = add_last(nd->r, val);
    update(nd);
    return nd;
}
// Persistently concatenates two treaps. Every value of `l` is expected to sit
// before every value of `r` in order. The root with the larger priority
// wins; on a tie the root of `r` wins. Nodes along the merge path are copied,
// so neither input version is modified.
node* merge(node* l, node* r) {
    if (!l || !r)
        return l ? l : r;

    if (l->prior <= r->prior) {
        // Root of `r` stays on top; merge `l` into its left spine.
        node* left_part = merge(l, r->l);
        node* fresh = copy(r);
        fresh->l = left_part;
        update(fresh);
        return fresh;
    }

    // Root of `l` stays on top; merge `r` into its right spine.
    node* right_part = merge(l->r, r);
    node* fresh = copy(l);
    fresh->r = right_part;
    update(fresh);
    return fresh;
}
// Persistently splits `nd` by value. The first tree receives the nodes whose
// value is strictly less than `val`; the second receives the nodes whose value
// is >= `val`. Nodes on the split path are copied; the input is untouched.
pair<node*, node*> split_value(node* nd, ll val) {
    if (nd == nullptr)
        return {nullptr, nullptr};

    if (nd->val < val) {
        // This node and its left subtree belong to the "less" side.
        auto parts = split_value(nd->r, val);
        node* cur = copy(nd);
        cur->r = parts.first;
        update(cur);
        return {cur, parts.second};
    }

    // This node and its right subtree belong to the ">=" side.
    auto parts = split_value(nd->l, val);
    node* cur = copy(nd);
    cur->l = parts.second;
    update(cur);
    return {parts.first, cur};
}
// Persistently splits `nd` by position. The first tree receives the first `sz`
// nodes in order and the second receives the rest. Nodes on the split path
// are copied; the input is untouched.
pair<node*, node*> split_size(node* nd, int sz) {
    if (nd == nullptr)
        return {nullptr, nullptr};

    const int left_sz = get_size(nd->l);
    if (sz > left_sz) {
        // This node (and its whole left subtree) go to the first part; the
        // remaining quota is taken from the right subtree.
        auto parts = split_size(nd->r, sz - left_sz - 1);
        node* cur = copy(nd);
        cur->r = parts.first;
        update(cur);
        return {cur, parts.second};
    }

    auto parts = split_size(nd->l, sz);
    node* cur = copy(nd);
    cur->l = parts.second;
    update(cur);
    return {parts.first, cur};
}
// Returns a new treap version containing every value of `nd` plus `val`,
// placed before existing equal values. `nd` may be null (empty tree); the
// input version is left untouched.
node* insert(node* nd, ll val) {
    // std::rand() is only guaranteed to reach 32767 (and <cstdlib> was never
    // included), which produces many priority ties on large trees. A
    // fixed-seed Mersenne Twister gives a full 31-bit, reproducible range.
    static std::mt19937 rng(20240601u);

    auto [l, r] = split_value(nd, val);
    node* v = new node();
    v->val = v->sm = val;
    v->sz = 1;
    v->prior = static_cast<int>(rng() >> 1);  // drop one bit so it fits in int
    return merge(l, merge(v, r));
}
// Appends every value of the treap `nd` to `ans` in reverse in-order
// (right subtree, node, left subtree), i.e. from largest to smallest.
void get_values(node* nd, vector<ll>& ans) {
    vector<node*> pending;
    node* cur = nd;
    while (cur != nullptr || !pending.empty()) {
        // Walk as far right as possible, remembering the path.
        while (cur != nullptr) {
            pending.push_back(cur);
            cur = cur->r;
        }
        cur = pending.back();
        pending.pop_back();
        ans.push_back(cur->val);
        cur = cur->l;
    }
}
// Capacity bound for the vertex count (100000 plus slack).
constexpr int MAXN = 100'000 + 123;
// edg[v]: (neighbour, directed-edge id) pairs for the edges leaving vertex v.
vector<pair<int, int>> edg[MAXN];
// Per directed edge. Ids 2i and 2i+1 are the two orientations of tree edge i;
// ids from 2n upward are the virtual "root" edges set up in main().
ll cost[MAXN * 3];
pair<int, int> edges[MAXN * 3];  // (from, to) endpoints of each directed edge
node* dp[MAXN * 3];              // memoized treap for each directed edge
// Merges the multisets stored in the treaps `nds` and keeps only the `k`
// largest values. The largest input (the first one on ties) is used as the
// base, and the values of every other input are inserted into it one by one
// (small-to-large). Inputs are never modified; an empty `nds` yields nullptr.
//
// Fix: the original filled a throw-away vector with every value of the result
// twice (before and after trimming) and never used it. That was an O(k)
// traversal per call for nothing. `nds` is also no longer copied.
node* merge_nodes(const vector<node*>& nds, int k) {
    node* basend = nullptr;
    for (node* nd : nds) {
        if (get_size(basend) < get_size(nd))
            basend = nd;
    }

    node* ans = basend;
    vector<ll> add;  // reused buffer for the values of each smaller tree
    for (node* nd : nds) {
        if (nd == basend)
            continue;
        add.clear();
        get_values(nd, add);
        for (ll x : add)
            ans = insert(ans, x);
    }

    // Values are stored in increasing order, so dropping the first
    // (total - k) positions keeps the k largest.
    const int total = get_size(ans);
    if (total > k)
        ans = split_size(ans, total - k).second;
    return ans;
}
void get_dp(int e, int k) {
if (get_size(dp[e]))
return;
auto [rt, v] = edges[e];
//if (k == 1) {
//dp[e] = {0};
//}
vector<node*> nds;
for (auto [u, ne] : edg[v]) {
if (u == rt)
continue;
get_dp(ne, k);
nds.push_back(dp[ne]);
//if (k == 1) {
//dp[e][0] = max(dp[e][0], dp[ne][0]);
//continue;
//}
//vector<ll> ndp(dp[e].size() + dp[ne].size());
//merge(dp[e].rbegin(), dp[e].rend(), dp[ne].rbegin(), dp[ne].rend(), ndp.rbegin());
//ndp.resize(min(int(ndp.size()), k));
//dp[e] = ndp;
}
//cout << "# ";
//for (auto nd : nds)
//cout << get_sum(nd) << ' ';
dp[e] = merge_nodes(nds, k);
//cout << ": " << get_sum(dp[e]);
if (get_size(dp[e])) {
//cout << get_sum(dp[e]) << " + " << cost[e] << " -> ";
dp[e] = add_last(dp[e], cost[e]);
//cout << get_sum(dp[e]) << endl;
//cout << "? " << cost[e] << ' ' << get_sum(dp[e]) << '\n';
} else {
dp[e] = insert(nullptr, cost[e]);
//cout << "! " << cost[e] << ' ' << get_sum(dp[e]) << '\n';
//dp[e] = {cost[e]};
}
//cout << " | " << get_sum(dp[e]) << endl;
}
// Reads a tree of n vertices (1-based endpoints, edge weights) and a bound k.
// For every vertex taken as the root, it prints the sum of the dp treap of a
// virtual edge entering that root from the pseudo-vertex n.
signed main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    int n, k;
    cin >> n >> k;

    // Tree edge i becomes two directed edges: 2i (u -> v) and 2i+1 (v -> u).
    for (int i = 0; i < n - 1; ++i) {
        int u, v, c;
        cin >> u >> v >> c;
        --u;
        --v;
        const int fwd = 2 * i;
        const int bwd = fwd + 1;
        edg[u].push_back({v, fwd});
        edg[v].push_back({u, bwd});
        edges[fwd] = {u, v};
        edges[bwd] = {v, u};
        cost[fwd] = c;
        cost[bwd] = c;
    }

    // Virtual edge 2n + root enters `root` from vertex n, which has no
    // adjacency list. Its cost is never assigned and stays zero-initialized.
    for (int root = 0; root < n; ++root) {
        const int e = 2 * n + root;
        edges[e] = {n, root};
        get_dp(e, k);
        cout << get_sum(dp[e]) << '\n';
    }
}
Compilation message (stderr)
Main.cpp: In function 'node* add_last(node*, ll)':
Main.cpp:45:8: warning: unused variable 'ndsm' [-Wunused-variable]
45 | ll ndsm = get_sum(nd);
| ^~~~
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |