Submission #1075916

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
| 1075916 | (not shown) | _8_8_ | Petrol stations (CEOI24_stations) | C++17 | 100 / 100 | 1191 ms | 275256 KiB |
#include <bits/stdc++.h>

using namespace std;
    
// 64-bit integer shorthand used for distances, fuel values and answers.
using ll = long long;
// N: capacity of every per-vertex array (vertices are stored 1-indexed).
// MOD: not referenced by the visible code; kept for compatibility.
const int N = 7e4 + 12;
const int MOD = 998244353;

// Fixed seed, so treap priorities (and therefore the run) are reproducible.
mt19937 rng(123231);

// Treap node ordered (in-order) by `rem`.
//   prior - heap priority (larger priorities sit closer to the root)
//   rem   - the stored value; already includes every shift applied to this node
//   col   - multiplicity carried by this node itself
//   sum   - total multiplicity of the subtree (maintained by upd())
//   add   - shift of `rem` not yet handed down to the children (see push())
struct node {
    node *l = nullptr;
    node *r = nullptr;
    ll prior = 0;
    ll rem = 0;
    ll col = 0;
    ll sum = 0;
    ll add = 0;

    // Leaf with value `val` and multiplicity `c`; draws one random priority.
    node(ll val, ll c) : prior(rng()), rem(val), col(c), sum(c) {}

    // Sentinel node: every field zero, no children.
    node() = default;
};
using pnode = node *;
// Total multiplicity stored in the subtree rooted at v; an empty tree holds 0.
int sum(pnode v) {
    if (v == nullptr) {
        return 0;
    }
    return v->sum;
}
// Recomputes x->sum from the children's totals plus x's own multiplicity.
// x must be non-null and its children's sums must already be current.
void upd(pnode x) {
    const int below = sum(x->l) + sum(x->r);
    x->sum = below + x->col;
}
void inc(pnode v,int val) {
    if(v) {
        v->rem += val;
        v->add += val;
    }
}
// Hands the pending shift stored at v down to both children, then clears it.
// v must be non-null; inc() itself tolerates null children.
void push(pnode v) {
    const ll pending = v->add;
    inc(v->l, pending);
    inc(v->r, pending);
    v->add = 0;
}
// Concatenates treaps a and b (all of a precedes all of b in order) and
// returns the new root. The node with the larger priority becomes the root;
// on equal priorities a's root wins. Pending shifts are pushed before descending.
pnode merge(pnode a, pnode b) {
    if (a == nullptr || b == nullptr) {
        return a ? a : b;
    }
    if (a->prior >= b->prior) {
        push(a);
        a->r = merge(a->r, b);
        upd(a);
        return a;
    }
    push(b);
    b->l = merge(a, b->l);
    upd(b);
    return b;
}
pair<pnode,pnode> split(pnode v,int key) {
    if(!v) return {0,0};
    push(v);
    if(v->rem < key) {
        auto [l,r] = split(v->r,key);
        v->r = l;
        upd(v);
        return {v,r};
    } else {
        auto [l,r] = split(v->l,key);
        v->l = r;
        upd(v);
        return {l,v};
    }   
}
// Adjacency lists built by test(): g[v] holds (neighbour, edge length),
// G[v] holds the neighbours only.
vector<pair<int,int>> g[N];
vector<int> G[N];
// n: number of vertices; k: full-tank fuel amount (values are measured against it).
int n;
int k;
// s[v]: subtree size computed by the most recent prec() call.
int s[N];
// blocked[v]: v has already been chosen as a centroid by decompose().
bool blocked[N];
// Fills s[] with subtree sizes of the unblocked component containing v,
// rooted at v. pr is the parent to skip (-1 at the root).
void prec(int v, int pr = -1) {
    int size = 1;
    for (const int child : G[v]) {
        if (child != pr && !blocked[child]) {
            prec(child, v);
            size += s[child];
        }
    }
    s[v] = size;
}
// Starting at v, repeatedly steps into the first unblocked child (other than
// the previous vertex) whose subtree exceeds total/2, and returns the vertex
// where no such child exists (a centroid). Relies on s[] from prec().
int find(int v, int pr, int total) {
    const int half = total / 2;
    while (true) {
        int heavy = -1;
        for (const int to : G[v]) {
            if (to != pr && !blocked[to] && s[to] > half) {
                heavy = to;
                break;
            }
        }
        if (heavy == -1) {
            return v;
        }
        pr = v;
        v = heavy;
    }
}
// ans[v]: refuel count accumulated for v by go() and add(); printed by test().
ll ans[N];
// Treap of fuel values currently manipulated by go(), add() and ins().
pnode root;
// rem[v]: fuel left on arrival at the current centroid when starting from v
//         (set by add(); the centroid itself gets k).
// cc[v]:  per-vertex counter used by add() and reset by cl().
int rem[N];
int cc[N];
void ins(int val) {
    auto [l,r] = split(root,val);
    pnode nv = new node(val,1);
    root = merge(merge(l,nv),r);
}
// Zeroes cc[] for v and every unblocked vertex reachable from v without
// passing through pr.
void cl(int v, int pr) {
    cc[v] = 0;
    for (const int to : G[v]) {
        if (blocked[to] || to == pr) continue;
        cl(to, v);
    }
}
// Detaches the last (rightmost in-order) node of the treap rooted at v.
// Returns {remaining treap, detached node}; {nullptr, nullptr} for an empty tree.
// Both results have up-to-date sums.
pair<pnode,pnode> split1(pnode v) {
    if (!v) return {nullptr, nullptr};
    push(v);
    if (v->r) {
        auto [l, r] = split1(v->r);
        v->r = l;
        upd(v);
        return {v, r};
    }
    // v is the rightmost node: its left subtree becomes the remainder.
    pnode rest = v->l;
    v->l = nullptr;
    upd(v);  // previously skipped, which left the detached node's sum stale
    return {rest, v};
}
void go(int v,int pr,ll W) {
    auto [l,r] = split(root,W);
    int _s = sum(l);
    inc(r,-W);
    root = merge(r,new node(k - W,_s));
    ans[pr] += _s * 1ll * s[v];
    for(auto [to,w]:g[v]) {
        if(to == pr || blocked[to]) continue;
        go(to,v,w);
    }
    auto [t,f] = split1(root);
    inc(t,W);
    root = merge(l,t);
}
// total: size of the current centroid's component (set by decompose()).
// pt:    root of the branch add() is processing during the first pass.
int total;
int pt;
// REV: true during decompose()'s second, reversed pass; add() then leaves ans[] alone.
bool REV = false;
void add(int v,int pr, vector<pair<int,ll>> &st) {
    ll c = k;
    rem[v] = -1;
    int vert = -1;
    // for(int i = (int)st.size() - 1;i >= 0;i--) {
    //     c -= st[i].second;
    //     if(c < 0) {
    //         rem[v] = rem[st[i + 1].first];
    //         vert = st[i +1].first;
    //         break;
    //     }
    // }
    int l = -1,r = (int)st.size() - 1;
    while(r - l > 1) {
        int mid = (l +r) >> 1;
        ll val = k - (st.back().second - (mid ? st[mid - 1].second : 0));
        if(val >= 0) {
            r = mid;
        } else {
            l = mid;
        }
    }
    if(r == 0) {
        rem[v] = k - st.back().second;
    } else {
        vert = st[r].first;
        rem[v] = rem[vert];
    }
    // cout << v << ' ' << rem[v] << ' ' << vert << ' ' << "| " << r << '\n';
    assert(rem[v] >= 0);
    ins(rem[v]);  
    cc[v]++;
    for(auto [to,w]:g[v]) {
        if(to == pr || blocked[to]) continue;
        st.push_back({v,w + st.back().second});
        add(to,v,st);
        st.pop_back();
    }
    if(vert != -1) {
        cc[vert] += cc[v];
        if(!REV) ans[vert] += cc[v] * (total - s[pt]);
    }
}
void decompose(int v) {
    prec(v);
    v = find(v,-1,s[v]);
    blocked[v] = 1;
    prec(v);
    total = s[v];
    root = new node();
    rem[v] = k;
    ins(k);
    REV = false;
    cc[v] = 0;
    for(auto [to,w]:g[v]) if(!blocked[to]) {
        go(to,v,w);
        vector<pair<int,ll>> bf = {make_pair(v,w)};
        cl(to,v);
        pt =to;
        add(to,v,bf);
    }
    root = new node();
    reverse(g[v].begin(),g[v].end());
    REV = 1;
    for(auto [to,w]:g[v]) if(!blocked[to]) {
        go(to,v,w);
        vector<pair<int,ll>> bf = {make_pair(v,w)};
        cl(to,v);
        add(to,v,bf);
    }
    for(int to:G[v]) if(!blocked[to]) {
          decompose(to);
    }
}
// Reads n, k and the n-1 weighted edges (0-indexed endpoints, shifted to
// 1-indexed). Runs the decomposition from vertex 1 and prints ans[1..n],
// one value per line.
void test() {
    cin >> n >> k;
    for (int edge = 1; edge < n; ++edge) {
        int x, y, len;
        cin >> x >> y >> len;
        ++x;
        ++y;
        g[x].push_back({y, len});
        g[y].push_back({x, len});
        G[x].push_back(y);
        G[y].push_back(x);
    }
    decompose(1);
    for (int v = 1; v <= n; ++v) {
        cout << ans[v] << '\n';
    }
}
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    int t = 1; 
    // cin >> t;
    while(t--) {
        test();
    }
}

Compilation message (stderr)

Main.cpp: In function 'void add(int, int, std::vector<std::pair<int, long long int> >&)':
Main.cpp:139:8: warning: unused variable 'c' [-Wunused-variable]
  139 |     ll c = k;
      |        ^
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...