Submission #839568

# | Time | Username | Problem | Language | Result | Execution time | Memory
839568 |  | model_code | Truck Driver (IOI23_deliveries) | C++17 | 100 / 100 | 1439 ms | 39968 KiB
// correct/sol_na_full.cpp
//
// IOI 2023 "Truck Driver" (deliveries).
//
// max_time() returns 2 * sum over edges e of T_e * min(S_e, total - S_e), where
// S_e is the weight of the subtree below e and total is the sum of all weights.
// W[0] is stored incremented by one (city 0 counts as one extra visit, since the
// route starts and ends there).
//
// Because every weight at the root is >= 1, the nodes with 2 * S >= total form a
// single path starting at the root. The deepest of them (the weighted centroid)
// is found by descending a segment tree laid out in heavy-light order. Edges on
// the centroid-to-root path contribute (total - S) * T, all other edges S * T.
#if __has_include("deliveries.h")
#include "deliveries.h"
#endif

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <utility>
#include <vector>

using ll = long long;

class solver {
  // Segment tree over heavy-light positions.
  // Supports range-add on subtree weights, range sums, and finding the last
  // position whose subtree weight is at least half of the total.
  struct segtree {
    struct node {
      ll max_subtree_w = 0;  // max subtree weight in the range
      ll sum_wdotd = 0;      // sum of subtree weight * parent-edge length
      ll sum_d = 0;          // sum of parent-edge lengths
      ll lazy = 0;           // pending addition to every subtree weight in the range

      // Folds the pending addition into this node's aggregates (lazy is cleared by push).
      void apply_lazy() {
        max_subtree_w += lazy;
        sum_wdotd += sum_d * lazy;
      }

      // Combines two fully pushed ranges; the result never carries a pending lazy.
      node operator+(const node& other) const {
        node res;
        res.max_subtree_w = std::max(max_subtree_w, other.max_subtree_w);
        res.sum_wdotd = sum_wdotd + other.sum_wdotd;
        res.sum_d = sum_d + other.sum_d;
        return res;
      }
    };

    std::vector<node> tree;

    segtree() = default;
    explicit segtree(int n) : tree(4 * static_cast<std::size_t>(n)) {}

    // Builds node `ind` covering [L, R] from per-position weights W and edge lengths D.
    void build(int ind, int L, int R, const std::vector<ll>& W, const std::vector<ll>& D) {
      if (L == R) {
        tree[ind].max_subtree_w = W[L];
        tree[ind].sum_d = D[L];
        tree[ind].sum_wdotd = D[L] * W[L];
        return;
      }
      const int mid = (L + R) / 2;
      build(2 * ind, L, mid, W, D);
      build(2 * ind + 1, mid + 1, R, W, D);
      tree[ind] = tree[2 * ind] + tree[2 * ind + 1];
    }

    // Applies node `ind`'s pending addition and hands it down to its children.
    void push(int ind, int L, int R) {
      if (tree[ind].lazy == 0) return;
      if (L != R) {
        tree[2 * ind].lazy += tree[ind].lazy;
        tree[2 * ind + 1].lazy += tree[ind].lazy;
      }
      tree[ind].apply_lazy();
      tree[ind].lazy = 0;
    }

    // Aggregate over positions [i, j].
    node query(int ind, int L, int R, int i, int j) {
      push(ind, L, R);
      if (R < i || j < L) return node();
      if (i <= L && R <= j) return tree[ind];
      const int mid = (L + R) / 2;
      if (mid < i) return query(2 * ind + 1, mid + 1, R, i, j);
      if (j <= mid) return query(2 * ind, L, mid, i, j);
      return query(2 * ind, L, mid, i, j) + query(2 * ind + 1, mid + 1, R, i, j);
    }

    // Adds `by` to the subtree weight of every position in [i, j].
    void update(int ind, int L, int R, int i, int j, ll by) {
      // Pushing before the range check guarantees both children are clean
      // when the parent is recombined below.
      push(ind, L, R);
      if (R < i || j < L) return;
      if (i <= L && R <= j) {
        tree[ind].lazy += by;
        push(ind, L, R);
        return;
      }
      const int mid = (L + R) / 2;
      update(2 * ind, L, mid, i, j, by);
      update(2 * ind + 1, mid + 1, R, i, j, by);
      tree[ind] = tree[2 * ind] + tree[2 * ind + 1];
    }

    // Rightmost position p in [L, R] with 2 * subtree_w(p) >= val.
    // The caller must ensure such a position exists; the root always qualifies.
    int find_last(int ind, int L, int R, ll val) {
      push(ind, L, R);
      if (L == R) return L;
      const int mid = (L + R) / 2;
      push(2 * ind, L, mid);
      push(2 * ind + 1, mid + 1, R);
      if (2 * tree[2 * ind + 1].max_subtree_w >= val) return find_last(2 * ind + 1, mid + 1, R, val);
      return find_last(2 * ind, L, mid, val);
    }
  };

  int N = 0;
  std::vector<std::vector<std::pair<int, int>>> adj;  // (neighbour, edge length)
  std::vector<ll> W;                                  // current weights, W[0] stored +1
  std::vector<int> sz, hld_nxt, par, par_D;           // size, heavy child, parent, parent-edge length
  std::vector<ll> subtree_sum;                        // initial subtree weights
  std::vector<int> hld, hld_head, hld_inv;            // position, head of own chain, position -> node
  segtree st;
  ll sum_w = 0;                                       // sum of W (including the +1 at city 0)

  // Roots the tree at `root` and fills par, par_D, sz, subtree_sum and the heavy
  // child hld_nxt (first child in adjacency order with the largest size).
  // Iterative, so path-shaped trees do not exhaust the call stack.
  void dfs_sz(int root) {
    std::vector<int> order;
    order.reserve(N);
    std::vector<int> stack{root};
    par[root] = -1;
    while (!stack.empty()) {
      const int x = stack.back();
      stack.pop_back();
      order.push_back(x);
      for (const auto& [y, d] : adj[x]) {
        if (y == par[x]) continue;
        par[y] = x;
        par_D[y] = d;
        stack.push_back(y);
      }
    }
    // Every child appears after its parent in `order`, so a reverse sweep
    // finalises all children before their parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
      const int x = *it;
      sz[x] = 1;
      subtree_sum[x] = W[x];
      hld_nxt[x] = -1;
      for (const auto& edge : adj[x]) {
        const int y = edge.first;
        if (y == par[x]) continue;
        sz[x] += sz[y];
        subtree_sum[x] += subtree_sum[y];
        if (hld_nxt[x] == -1 || sz[y] > sz[hld_nxt[x]]) hld_nxt[x] = y;
      }
    }
  }

  // Assigns heavy-light positions in preorder, heavy child first and light
  // children in adjacency order. Every heavy chain is therefore a contiguous
  // range starting at its head. Iterative for the same reason as dfs_sz.
  void dfs_hld(int root) {
    int next_pos = 0;
    std::vector<int> stack{root};
    hld_head[root] = root;
    while (!stack.empty()) {
      const int x = stack.back();
      stack.pop_back();
      hld[x] = next_pos;
      hld_inv[next_pos] = x;
      ++next_pos;
      const int heavy = hld_nxt[x];
      if (heavy == -1) continue;
      // Light children are pushed in reverse order and the heavy child last,
      // so the heavy child is popped next and the light children keep their order.
      for (auto it = adj[x].rbegin(); it != adj[x].rend(); ++it) {
        const int y = it->first;
        if (y == par[x] || y == heavy) continue;
        hld_head[y] = y;
        stack.push_back(y);
      }
      hld_head[heavy] = hld_head[x];
      stack.push_back(heavy);
    }
  }

 public:
  solver() = default;

  // Builds the structure for a tree of N cities.
  // Edge i joins U_[i] and V_[i] with length T_[i]; W_ holds the initial weights.
  solver(int N, const std::vector<int>& U_, const std::vector<int>& V_, const std::vector<int>& T_,
         const std::vector<int>& W_)
      : N(N), adj(N), W(W_.begin(), W_.end()) {
    for (int i = 0; i + 1 < N; ++i) {
      adj[U_[i]].emplace_back(V_[i], T_[i]);
      adj[V_[i]].emplace_back(U_[i], T_[i]);
    }
    W[0]++;
    for (ll w : W) sum_w += w;

    subtree_sum.assign(N, 0);
    sz.assign(N, 0);
    hld_nxt.assign(N, -1);
    par.assign(N, -1);
    par_D.assign(N, 0);
    dfs_sz(0);

    hld.assign(N, -1);
    hld_inv.assign(N, -1);
    hld_head.assign(N, -1);
    dfs_hld(0);

    std::vector<ll> subtree_sum_hld(N), par_D_hld(N);
    for (int i = 0; i < N; ++i) {
      subtree_sum_hld[hld[i]] = subtree_sum[i];
      par_D_hld[hld[i]] = par_D[i];
    }
    st = segtree(N);
    st.build(1, 0, N - 1, subtree_sum_hld, par_D_hld);
  }

  // Half of the answer, taking node p as the weighted centroid.
  // Edges on the path from p up to the root contribute (sum_w - subtree_w) * d;
  // every other edge contributes subtree_w * d.
  ll calc_answer(ll p) {
    ll ans = st.query(1, 0, N - 1, 0, N - 1).sum_wdotd;
    for (int v = static_cast<int>(p); v >= 0; v = par[hld_head[v]]) {
      // One query per chain: replace subtree_w * d by (sum_w - subtree_w) * d.
      const auto seg = st.query(1, 0, N - 1, hld[hld_head[v]], hld[v]);
      ans += seg.sum_d * sum_w - 2 * seg.sum_wdotd;
    }
    return ans;
  }

  // Sets the weight of city p to new_w and returns the new maximum delivery time.
  ll change(ll p, ll new_w) {
    if (p == 0) new_w++;  // city 0 keeps its +1
    const ll delta = new_w - W[p];
    W[p] = new_w;
    sum_w += delta;
    // Every ancestor of p (including p) gains delta in subtree weight.
    for (int v = static_cast<int>(p); v >= 0; v = par[hld_head[v]]) {
      st.update(1, 0, N - 1, hld[hld_head[v]], hld[v], delta);
    }
    return 2 * calc_answer(hld_inv[st.find_last(1, 0, N - 1, sum_w)]);
  }
};

solver s;

void init(int N, std::vector<int> U, std::vector<int> V, std::vector<int> T, std::vector<int> W) {
  s = solver(N, U, V, T, W);
}

long long max_time(int X, int Y) { return s.change(X, Y); }
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...