Submission #716607

#TimeUsernameProblemLanguageResultExecution timeMemory
716607StickfishPaths (RMI21_paths)C++17
36 / 100
632 ms524288 KiB
#include <iostream> #include <vector> #include <algorithm> using namespace std; using ll = long long; struct node { node* l = nullptr; node* r = nullptr; int prior; ll val = 0; ll sm = 0; int sz = 0; }; node* copy(node* nd) { node* v = new node(); v->l = nd->l; v->r = nd->r; v->val = nd->val; v->sm = nd->sm; v->prior = nd->prior; v->sz = nd->sz; return v; } int get_size(node* nd) { if (!nd) return 0; return nd->sz; } ll get_sum(node* nd) { if (!nd) return 0; return nd->sm; } void update(node* nd) { nd->sz = 1 + get_size(nd->l) + get_size(nd->r); nd->sm = nd->val + get_sum(nd->l) + get_sum(nd->r); } node* add_last(node* nd, ll val) { ll ndsm = get_sum(nd); nd = copy(nd); if (!nd->r) { nd->val += val; nd->sm += val; //cout << "( " << ndsm << " + " << val << " -> " << get_sum(nd) << ") "; return nd; } nd->r = add_last(nd->r, val); update(nd); return nd; } node* merge(node* l, node* r) { if (!l) return r; if (!r) return l; if (l->prior > r->prior) { auto lr = merge(l->r, r); l = copy(l); l->r = lr; update(l); return l; } else { auto rl = merge(l, r->l); r = copy(r); r->l = rl; update(r); return r; } } pair<node*, node*> split_value(node* nd, ll val) { if (!nd) return {nullptr, nullptr}; if (val <= nd->val) { auto [l, r] = split_value(nd->l, val); nd = copy(nd); nd->l = r; update(nd); return {l, nd}; } else { auto [l, r] = split_value(nd->r, val); nd = copy(nd); nd->r = l; update(nd); return {nd, r}; } } pair<node*, node*> split_size(node* nd, int sz) { if (!nd) return {nullptr, nullptr}; if (sz <= get_size(nd->l)) { auto [l, r] = split_size(nd->l, sz); nd = copy(nd); nd->l = r; update(nd); return {l, nd}; } else { auto [l, r] = split_size(nd->r, sz - 1 - get_size(nd->l)); nd = copy(nd); nd->r = l; update(nd); return {nd, r}; } } node* insert(node* nd, ll val) { auto [l, r] = split_value(nd, val); node* v = new node(); v->val = v->sm = val; v->sz = 1; v->prior = rand(); return merge(l, merge(v, r)); } void get_values(node* nd, vector<ll>& ans) { if (!nd) return; get_values(nd->r, ans); ans.push_back(nd->val); get_values(nd->l, ans); } const int MAXN = 1e5 + 123; vector<pair<int, int>> edg[MAXN]; ll cost[MAXN * 3]; pair<int, int> edges[MAXN * 3]; node* dp[MAXN * 3]; //vector<ll> dp[MAXN * 3]; node* merge_nodes(vector<node*> nds, int k) { node* basend = nullptr; for (auto nd : nds) { if (get_size(basend) < get_size(nd)) basend = nd; } node* ans = basend; for (auto nd : nds) { if (nd == basend) continue; vector<ll> add; get_values(nd, add); for (auto x : add) { ans = insert(ans, x); } } if (get_size(ans) > k) { vector<ll> vals; get_values(ans, vals); ans = split_size(ans, get_size(ans) - k).second; vals.clear(); get_values(ans, vals); } return ans; } void get_dp(int e, int k) { if (get_size(dp[e])) return; auto [rt, v] = edges[e]; //if (k == 1) { //dp[e] = {0}; //} vector<node*> nds; for (auto [u, ne] : edg[v]) { if (u == rt) continue; get_dp(ne, k); nds.push_back(dp[ne]); //if (k == 1) { //dp[e][0] = max(dp[e][0], dp[ne][0]); //continue; //} //vector<ll> ndp(dp[e].size() + dp[ne].size()); //merge(dp[e].rbegin(), dp[e].rend(), dp[ne].rbegin(), dp[ne].rend(), ndp.rbegin()); //ndp.resize(min(int(ndp.size()), k)); //dp[e] = ndp; } //cout << "# "; //for (auto nd : nds) //cout << get_sum(nd) << ' '; dp[e] = merge_nodes(nds, k); //cout << ": " << get_sum(dp[e]); if (get_size(dp[e])) { //cout << get_sum(dp[e]) << " + " << cost[e] << " -> "; dp[e] = add_last(dp[e], cost[e]); //cout << get_sum(dp[e]) << endl; //cout << "? " << cost[e] << ' ' << get_sum(dp[e]) << '\n'; } else { dp[e] = insert(nullptr, cost[e]); //cout << "! " << cost[e] << ' ' << get_sum(dp[e]) << '\n'; //dp[e] = {cost[e]}; } //cout << " | " << get_sum(dp[e]) << endl; } signed main() { ios_base::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr); int n, k; cin >> n >> k; //assert(k == 1); for (int i = 0; i + 1 < n; ++i) { int u, v, c; cin >> u >> v >> c; --u, --v; edg[u].push_back({v, i * 2}); edg[v].push_back({u, i * 2 + 1}); edges[i * 2] = {u, v}; edges[i * 2 + 1] = {v, u}; cost[i * 2] = cost[i * 2 + 1] = c; } for (int e = 2 * n; e < 3 * n; ++e) { edges[e] = {n, e - 2 * n}; get_dp(e, k); cout << get_sum(dp[e]) << '\n'; //ll ans = 0; //for (auto x : dp[e]) //ans += x; //cout << ans << '\n'; } }

Compilation message (stderr)

Main.cpp: In function 'node* add_last(node*, ll)':
Main.cpp:45:8: warning: unused variable 'ndsm' [-Wunused-variable]
   45 |     ll ndsm = get_sum(nd);
      |        ^~~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...