Submission #839570

# | Time | Username | Problem | Language | Result | Execution time | Memory
839570 | | model_code | Truck Driver (IOI23_deliveries) | C++17 | 100 / 100 | 1576 ms | 41688 KiB
// correct/sol_vp_full_hld.cpp

#include "deliveries.h"
#include <algorithm>
#include <vector>
#include <numeric>
#include <assert.h>
#include <iostream>

using namespace std;

/// Segment tree over positions [0, n) holding one integer "value" per position.
/// Supports:
///   - update(l, r, c):  add c to the value of every position in [l, r);
///   - query(x):         current value at position x;
///   - query_w(l, r):    sum over [l, r) of value[i] * weight[i];
///   - get_last(N):      last position whose value is >= N.
/// Range adds are never pushed down: a position's value is the sum of `add`
/// over all nodes on its root-to-leaf path.
struct SegmentTree{
    int n;  // number of leaves (a power of two)
    // Per node x:
    //   st[x]  = max over leaves below x of (sum of add[] from x down to the leaf)
    //   add[x] = pending amount added to every position in x's range
    //   w[x]   = total weight of the positions in x's range
    //   sum[x] = sum over leaves below x of (sum of add[] from x down) * weight
    vector<long long> st, add, w, sum;

    /// Smallest power of two that is >= max(v, 1).
    static int leaf_count(int v){
        int p = 1;
        while(p < v) p <<= 1;
        return p;
    }

    // Replaces `1 << (32 - __builtin_clz(n_))`, which is undefined for n_ == 0
    // and doubled the leaf count whenever n_ was already a power of two.
    SegmentTree(int n_) : n(leaf_count(n_)), st(n*2, 0), add(n*2, 0), w(2*n, 0), sum(2*n, 0) {}

    /// Loads per-position weights (weights[i] for position i; missing ones are 0)
    /// and recomputes subtree weight totals.
    void set_weights(const vector<long long> &weights){
        assert((int)weights.size() <= n);
        fill(w.begin(), w.end(), 0);
        copy(weights.begin(), weights.end(), w.begin() + n);
        for(int i = n-1; i > 0; i--){
            w[i] = w[2*i] + w[2*i+1];
        }
    }

    /// Recomputes st[x] and sum[x] from its children (or from add[x] at a leaf).
    void upd(int x){
        if(x < n) {
            st[x] = max(st[2*x], st[2*x+1]) + add[x];
            sum[x] = sum[2*x] + sum[2*x+1] + w[x] * add[x];
        } else {
            st[x] = add[x];
            sum[x] = add[x] * w[x];
        }
    }

    /// Adds c to every position of [L, R) within node x covering [l, r).
    void update(int L, int R, long long c, int x, int l, int r){
        if(L <= l && r <= R){
            add[x] += c;
            upd(x);
            return;
        }
        if(r <= L || R <= l) { return; }
        int m = (l + r) / 2;
        update(L, R, c, 2*x, l, m);
        update(L, R, c, 2*x+1, m, r);
        upd(x);
    }
    void update(int l, int r, long long c) { update(l, r, c, 1, 0, n); }

    /// Value at position x: sum of add[] along the leaf-to-root path.
    long long query(int x){
        long long res = 0;
        for(int i = x + n; i > 0; i >>= 1) { res += add[i]; }
        return res;
    }

    /// Sum over [L, R) of value * weight; s carries the adds of x's ancestors.
    long long query_w(int L, int R, int x, int l, int r, long long s){
        if(L <= l && r <= R){
            return sum[x] + s * w[x];
        }
        if(r <= L || R <= l) return 0;
        s += add[x];
        int m = (l + r) / 2;
        return query_w(L, R, 2*x, l, m, s) + query_w(L, R, 2*x+1, m, r, s);
    }
    long long query_w(int l, int r) { return query_w(l, r, 1, 0, n, 0); }

    /// Descends towards the right-most leaf whose value is >= N.
    /// The reached leaf is not re-checked, so the result is only meaningful
    /// when at least one position has value >= N.
    long long get_last(long long N, int x, int l, int r){
        if(l+1 == r) return l;
        N -= add[x];
        int m = (l + r) / 2;
        if(st[2*x+1] >= N) return get_last(N, 2*x+1, m, r);
        return get_last(N, 2*x, l, m);
    }
    long long get_last(long long N) { return get_last(N, 1, 0, n); }
};

/// Per-vertex data for the tree and its heavy-light decomposition (HLD).
struct node{
    vector<pair<int, long long>> edges;  // (neighbour, edge weight)
    long long dist;  // weighted distance from vertex 0, the DFS root
    long long w;     // weight of the edge to the parent (0 for the root)
    int sz;          // number of vertices in this subtree
    int up;          // vertex above this vertex's chain head, or -1 on the root chain
    int pos;         // index of this vertex in HLD order
    int L;           // HLD index of this vertex's chain head
    int M;           // one past the last HLD index of the heavy path starting here
    int R;           // one past the last HLD index of this vertex's subtree
};

// Tree storage, indexed by vertex id.
vector<node> g;
// inv[p] is the vertex placed at HLD position p.
vector<int> inv;
// Current visit count per vertex (vertex 0 carries one extra implicit visit).
vector<int> cnt;
// Visited marks shared by the DFS passes.
vector<bool> vis;
// Number of vertices.
int n;

// Segment tree indexed by HLD position; initial capacity 100000 positions.
SegmentTree st(100000);

void dfs(int x, long long len = 0, long long last = 0){
    vis[x] = true;
    g[x].dist = len;
    g[x].sz = 1;
    g[x].w = last;
    for(auto [y, w] : g[x].edges){
        if(!vis[y]){
            dfs(y, len + w, w);
            g[x].sz += g[y].sz;
        }
    }
}

// Next free position in HLD order; advanced by hld().
int cur = 0;

/// Heavy-light decomposition rooted at x (requires subtree sizes from dfs and
/// vis[] reset, with x's parent already marked visited).
/// Assigns pos (preorder index, heavy child first), L (chain head position),
/// up (vertex above the chain head, -1 for the root chain), M (one past the
/// end of the heavy path from x) and R (one past the end of x's subtree), and
/// appends each vertex to inv in position order.
void hld(int x, int up, int L){
    vis[x] = true;
    g[x].up = up;
    g[x].L = L;
    g[x].pos = cur++;
    inv.push_back(x);

    // Heavy child = unvisited neighbour with the largest subtree. The parent is
    // already visited and must be excluded: its subtree is always larger than
    // any child's, so ranking *all* neighbours by size (as the previous
    // nth_element did) put the parent first and left the real heavy child at
    // an arbitrary slot, breaking the O(log N) chains-per-path guarantee.
    int heavy = -1;
    for(const auto &edge : g[x].edges){
        const int y = edge.first;
        if(!vis[y] && (heavy == -1 || g[y].sz > g[heavy].sz)) heavy = y;
    }

    if(heavy == -1){
        // Leaf: the heavy path starting at x is x alone.
        g[x].M = g[x].pos + 1;
    } else {
        // The heavy child continues x's chain (same head position L).
        hld(heavy, up, L);
        g[x].M = g[heavy].M;
        // Every remaining child starts a new chain hanging below x.
        for(const auto &edge : g[x].edges){
            const int y = edge.first;
            if(!vis[y]) hld(y, x, cur);
        }
    }
    g[x].R = cur;
}

/// Adds c at every HLD position on the path from vertex x up to the root,
/// one chain segment [chain head, vertex] at a time. x == -1 is a no-op.
void update_hld(SegmentTree &st, int x, long long c){
    for(int v = x; v != -1; v = g[v].up){
        st.update(g[v].L, g[v].pos + 1, c);
    }
}

/// Sum of st.query_w over all HLD positions on the path from vertex x up to
/// the root, one chain segment at a time. Returns 0 for x == -1.
long long queryw_hld(SegmentTree &st, int x){
    long long total = 0;
    while(x != -1){
        total += st.query_w(g[x].L, g[x].pos + 1);
        x = g[x].up;
    }
    return total;
}

void init(int N, std::vector<int> U, std::vector<int> V, std::vector<int> T, std::vector<int> W) {
    n = N;
	g.resize(N);
    for(int i = 0; i < N-1; i++){
        g[U[i]].edges.emplace_back(V[i], T[i]);
        g[V[i]].edges.emplace_back(U[i], T[i]);
    }

    vis.assign(N, false);
    dfs(0);
    vis.assign(N, false);
    hld(0, -1, 0);

    cnt = W;
    cnt[0]++;

    vector<long long> weights(N);
    for(int i = 0; i < N; i++){
        weights[g[i].pos] = g[i].w;
    }

    st.set_weights(weights);

    for(int i = 0; i < N; i++){
        update_hld(st, i, cnt[i]);
    }

    return;
}

/// Sets the count of vertex S to X (vertex 0 stores X + 1) and returns the
/// value computed from the updated segment tree.
long long max_time(int S, int X) {
    const int visits = X + (S == 0 ? 1 : 0);
    update_hld(st, S, visits - cnt[S]);
    cnt[S] = visits;

    // Position 0 accumulates every root-path update, so it holds the total.
    const long long total = st.query(0);

    // Last HLD position whose accumulated count reaches ceil(total / 2).
    int pivot_pos = 0;
    if (total != 0) {
        pivot_pos = st.get_last((total + 1) / 2);
    }
    const int pivot = inv[pivot_pos];

    const long long path_weight = queryw_hld(st, pivot);
    const long long all_weight = st.query_w(0, n);

    return 2 * (all_weight - 2 * path_weight + g[pivot].dist * total);
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...