Submission #839568

# | Time | Username | Problem | Language | Result | Execution time | Memory
839568 | model_code | Truck Driver (IOI23_deliveries) | C++17
100 / 100
1439 ms | 39968 KiB
// correct/sol_na_full.cpp

#include "deliveries.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <utility>
#include <vector>

using std::cerr;  // debug output stream; not referenced by the code visible in this file

// Shorthands for std::pair members (e.g. p.xx expands to p.first, p.yy to p.second).
#define xx first
#define yy second

using ll = long long;  // 64-bit type for weights, distances and their products

class solver {
    struct segtree {
        struct node {
            ll max_subtree_w;
            ll sum_wdotd, sum_d;
            ll lazy;

            node() : max_subtree_w(0), sum_wdotd(0), sum_d(0), lazy(0) {}

            void apply_lazy() {
                max_subtree_w += lazy;
                sum_wdotd += sum_d * lazy;
            }

            node operator+(const node& other) const {
                node res = *this;
                res.max_subtree_w = std::max(res.max_subtree_w, other.max_subtree_w);
                res.sum_wdotd = res.sum_wdotd + other.sum_wdotd;
                res.sum_d = res.sum_d + other.sum_d;

                return res;
            }
        };

        std::vector<node> tree;

        segtree() {}
        segtree(int N) {
            tree.resize(4*N);
        }

        void build(int ind, int L, int R, std::vector<ll>& W, std::vector<ll>& D) {
            if(L==R) {
                tree[ind].max_subtree_w = W[L];
                tree[ind].sum_d = D[L];
                tree[ind].sum_wdotd = (ll)D[L]*W[L];
                return ;
            }else {
                build(2*ind, L, (L+R)/2, W, D);
                build(2*ind+1, (L+R)/2+1, R, W, D);

                tree[ind] = std::move(tree[2*ind]+tree[2*ind+1]);
                return ;
            }
        }

        void push(int ind, int L, int R) {
            if(tree[ind].lazy!=0) {
                if(L!=R) {
                    tree[2*ind].lazy+=tree[ind].lazy;
                    tree[2*ind+1].lazy+=tree[ind].lazy;
                }

                tree[ind].apply_lazy();
                tree[ind].lazy=0;
            }
        }
        
        node query(int ind, int L, int R, int i, int j) {
            push(ind, L, R);
            if(R<i || j<L) return node();
            if(i<=L && R<=j) return tree[ind];
            int mid=(L+R)/2;
            if(mid<i) 
                return query(2*ind+1, mid+1, R, i, j);
            else if(j<=mid)
                return query(2*ind, L, mid, i, j);
            else
                return query(2*ind, L, mid, i, j) + query(2*ind+1, mid+1, R, i, j);
        }
        
        void update(int ind, int L, int R, int i, int j, ll by) {
            push(ind, L, R);
            if(R<i || j<L) return ;
            if(i<=L && R<=j) {
                tree[ind].lazy+=by;
                push(ind, L, R);
                return ;
            }
            update(2*ind, L, (L+R)/2, i, j, by);
            update(2*ind+1, (L+R)/2+1, R, i, j, by);
            tree[ind]=std::move(tree[2*ind]+tree[2*ind+1]);
        }

        int find_last(int ind, int L, int R, ll val) {
            push(ind, L, R);
            if(L!=R) {
                push(2*ind, L, (L+R)/2);
                push(2*ind+1, (L+R)/2+1, R);
            }

            if(L==R) return L;
            if(2*tree[2*ind+1].max_subtree_w>=val) {
                return find_last(2*ind+1, (L+R)/2+1, R, val);
            }
            return find_last(2*ind, L, (L+R)/2, val);
        }
    };

    int N;
    std::vector<std::vector<std::pair<int,int>>> adj;
    std::vector<int> W;

    std::vector<int> sz, hld_nxt, par, par_D;
    std::vector<ll> subtree_sum;
    void dfs_sz(int x) {
        par[x]=-1;
        subtree_sum[x]=W[x];
        sz[x]=1;

        hld_nxt[x]=-1;
        
        for(auto i:adj[x]) {
            if(!sz[i.xx]) {
                dfs_sz(i.xx);
                
                par_D[i.xx]=i.yy;
                subtree_sum[x]+=subtree_sum[i.xx];
                par[i.xx]=x;
                sz[x]+=sz[i.xx];

                if(hld_nxt[x]==-1 || sz[i.xx]>sz[hld_nxt[x]]) {
                    hld_nxt[x]=i.xx;
                }
            }
        }
    }

    std::vector<int> hld, hld_id, hld_head, hld_inv;
    int hld_pos, hld_next_id;
    void dfs_hld(int x) {
        hld[x]=hld_pos++;
        hld_inv[hld_pos-1]=x;
        hld_id[x]=hld_next_id;
        if(hld_nxt[x]!=-1) {
            dfs_hld(hld_nxt[x]);
            
            for(auto i:adj[x]) {
                if(hld_nxt[x]!=i.xx && par[i.xx]==x) {
                    hld_next_id++;
                    dfs_hld(i.xx);
                }
            }
        }
        
        hld_head[hld_id[x]]=x;
    }

    segtree st;
    ll sum_w=0;
    int lg;
public:
    solver() {}

    solver(int N, std::vector<int> U_, std::vector<int> V_, std::vector<int> T_, std::vector<int> W_) : N(N), hld_pos(0), hld_next_id(0) {
        lg = log2(N)+1;
        adj.resize(N);
        for(int i=0;i<N-1;++i) {
            adj[U_[i]].push_back({V_[i], T_[i]});
            adj[V_[i]].push_back({U_[i], T_[i]});
        }
        W = std::move(W_);
        W[0]++;
        for(int w:W) sum_w+=w;

        subtree_sum.assign(N, 0);
        sz.assign(N, 0);
        hld_nxt.assign(N, 0);
        par.assign(N, -1);
        par_D.assign(N, 0);

        dfs_sz(0);
        hld.assign(N, -1);
        hld_id.assign(N, -1);
        hld_inv.assign(N, -1);
        hld_head.assign(N, -1);
        dfs_hld(0);

        st = segtree(N);
        std::vector<ll> subtree_sum_hld(N), par_D_hld(N);
        for(int i=0;i<N;++i) {
            subtree_sum_hld[hld[i]]=subtree_sum[i];
            par_D_hld[hld[i]]=par_D[i];
        }

        st.build(1,0,N-1, subtree_sum_hld, par_D_hld);
    }

    // (sum_w-subtreeSumW[i])*par_d[i] a centroidig a többi subtreeSumW[i]*par_d[i]
    // 3 * 0 + 1 * 1 - 
    ll calc_answer(ll p) {
        ll ans = st.query(1, 0, N-1, 0, N-1).sum_wdotd;
        while(1) {
            ans += -2*st.query(1, 0, N-1,  hld[hld_head[hld_id[p]]], hld[p]).sum_wdotd
                   + st.query(1, 0, N-1,  hld[hld_head[hld_id[p]]], hld[p]).sum_d * sum_w;
            p = par[hld_head[hld_id[p]]];
            if(p<0) break ;
        }

        return ans;
    }

    ll change(ll p, ll new_w) {
        if(p==0) new_w++;
        ll prv_w = W[p];
        ll change = new_w - prv_w;
        W[p] = new_w;
        sum_w+=change;
        while(1) {
            st.update(1, 0, N-1, hld[hld_head[hld_id[p]]], hld[p], change);
            p = par[hld_head[hld_id[p]]];
            if(p<0) break ;
        }
        
        return 2*calc_answer(hld_inv[st.find_last(1,0,N-1,sum_w)]);
    }
};

// Global solver state shared by init() and max_time().
solver s;

// Builds the solver for a new tree. The by-value argument vectors are moved
// into the solver rather than copied a second time.
void init(int N, std::vector<int> U, std::vector<int> V, std::vector<int> T, std::vector<int> W) {
    s = solver(N, std::move(U), std::move(V), std::move(T), std::move(W));
}

// Grader entry point: forwards (city X, new count Y) to the global solver and
// returns whatever solver::change computes.
long long max_time(int X, int Y) {
    const long long result = s.change(X, Y);
    return result;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...