Submission #1366528

#TimeUsernameProblemLanguageResultExecution timeMemory
1366528mariaclaraPetrol stations (CEOI24_stations)C++20
48 / 100
3583 ms19076 KiB
#include<bits/stdc++.h>

using namespace std;

typedef long long ll;
typedef pair<int,int> pii;
typedef tuple<int,int,int> trio;
typedef vector<int> vi;
const int MAXN = 7e4+5;
#define all(x) x.begin(), x.end()
#define sz(x) (int)x.size()
#define mk make_pair 
#define pb push_back
#define fr first 
#define sc second

int n, K;
ll ans[MAXN];
vector<vector<trio>> edges;

int T, allowed[MAXN], pp[MAXN], pe[MAXN], SUB[2*MAXN], tam[MAXN];
int tin[MAXN], tout[MAXN];
ll niv[MAXN];

void construct_euler(int x) {
    tin[x] = ++T;

    for(auto [viz, p, id] : edges[x]) {
        if(!allowed[viz] or viz == pp[x]) continue;

        niv[viz] = niv[x] + p;
        tam[viz] = tam[x] + 1;
        pp[viz] = x;
        construct_euler(viz);
    }

    tout[x] = T;
}

int al[MAXN], ar[MAXN];
vector<pair<ll,int>> att;

void calc_niv(int x) {
    att.pb({niv[x], x});

    if(niv[x] > K) {
        int idr = lower_bound(all(att), mk(niv[x] - K, 0)) - att.begin();
        ar[x] = att[idr].sc;
        int idl = lower_bound(all(att), mk(niv[pp[x]] - K, 0)) - att.begin();
        al[x] = att[idl+1].sc;

        if(idl+1 > idr) al[x] = ar[x] = -1;
    }

    for(auto [viz, p, id] : edges[x]) {
        if(!allowed[viz] or viz == pp[x]) continue;

        niv[viz] = niv[x] + p;
        tam[viz] = tam[x] + 1;
        pp[viz] = x;
        pe[viz] = id;
        calc_niv(viz);
    }

    att.pop_back();
}

ll rsp[MAXN], bit[MAXN];

void update(int x, ll val) {
    for(x++; x <= T+1; x += x&-x)
        bit[x] += val;
}

int query(int x) {
    int sum = 0;
    x++;

    while(x > 0) {
        sum += bit[x];
        x -= x&-x;
    }

    return sum;
}

void att_ans(vi &a, vi &b, int c) {
    for(auto it : a) allowed[it] = 1;
    niv[c] = tam[c] = 0;
    pp[c] = T = -1;

    construct_euler(c);
    
    for(auto it : a) allowed[it] = 0;

    vector<tuple<ll, int, int>> sum;
    vi ord;

    for(auto x : a) {
        if(x == c) continue;
        sum.pb({K+niv[pp[x]], x, 1});
        sum.pb({K+niv[x], x, -1});
        ord.pb(x);
    }

    // K+niv[x] >= lev > K+niv[pp[x]]

    sort(all(sum));
    reverse(all(sum));
    sort(all(ord), [](int a, int b){
        return make_tuple(niv[a], tam[a], a) > make_tuple(niv[b], tam[b], b);
    });

    for(int i = 0, j = 0; i < sz(sum); i++) {
        auto [lev, x, fl] = sum[i];

        while(j < sz(ord) and niv[ord[j]] > lev)
            update(tin[ord[j]], rsp[ord[j]]+1), j++;
    
        rsp[x] += (query(tout[x]) - query(tin[x]-1)) * fl;
    }

    for(auto it : a) if(it != c) rsp[it]++;
    for(auto it : b) allowed[it] = 1, al[it] = ar[it] = -1;
    niv[c] = tam[c] = 0;
    pp[c] = -1;

    calc_niv(c);

    sum.clear();

    for(auto x : b) {
        if(x == c) continue;
        sum.pb({K - niv[x], x, -1});
        sum.pb({K - niv[pp[x]], x, 1});
    }

    // K - niv[x] < lev <= K-niv[pp[x]]

    sort(all(sum));
    reverse(all(ord));

    fill(bit, bit+T+2, 0);

    for(int i = 0, j = 0; i < sz(sum); i++) {
        auto [lev, x, fl] = sum[i];

        while(j < sz(ord) and niv[ord[j]] <= lev) 
            update(tin[ord[j]], rsp[ord[j]]), j++;

        rsp[x] += query(T) * fl;
    }

    ord.clear();
    for(auto it : b) ord.pb(it);
    sort(all(ord), [](int a, int b){
        return mk(tam[a], a) < mk(tam[b], b);
    });

    for(int it : b) niv[it] = 0;

    for(auto x : ord) { // aqui
        if(x == c) continue;
        if(al[x] != -1) rsp[x] += niv[ar[x]] - niv[al[x]] + rsp[al[x]];
        niv[x] = niv[pp[x]] + rsp[x];
    }

    for(auto x : b)
        if(x != c) ans[pp[x]] += rsp[x] * SUB[pe[x]];

    for(auto it : b) allowed[it] = 0;
    fill(bit, bit+T+2, 0);
    for(int it : a) rsp[it] = 0;
    for(int it : b) rsp[it] = 0;
}

int blocked[MAXN], sub[MAXN];

void calc_sub(int x, int pai) {
    sub[x] = 1;

    for(auto [viz, p, id] : edges[x]) {
        if(viz == pai or blocked[viz]) continue;
        calc_sub(viz, x);
        sub[x] += sub[viz];
    }
}

int calc_centroid(int x, int tot, int pai) {
    for(auto [viz, p, id] : edges[x]) {
        if(viz == pai or blocked[viz]) continue;
        if(sub[viz]*2 > tot) return calc_centroid(viz, tot, x);
    }
    return x;
}

vi A, B;

void addA(int x) {
    A.pb(x);

    for(auto [viz, p, id] : edges[x]) {
        if(blocked[viz] or sub[viz] > sub[x]) continue;
        addA(viz);
    }
}

void addB(int x) {
    B.pb(x);

    for(auto [viz, p, id] : edges[x]) {
        if(blocked[viz] or sub[viz] > sub[x]) continue;
        addB(viz);
    }
}

void decomp(vi v) {
    if(sz(v) <= 2) return;

    for(auto it : v) blocked[it] = 0, sub[it] = 0;

    calc_sub(v[0], -1);
    int c = calc_centroid(v[0], sz(v), -1);

    for(auto it : v) sub[it] = 0;
    calc_sub(c, -1);

    A.clear(); B.clear();
    int Sa = 0, Sb = 0;
    A.pb(c);
    B.pb(c);

    vi V;

    for(auto [viz, p, aux] : edges[c]) {
        if(blocked[viz]) continue;
        V.pb(viz);
    }

    sort(all(V), [](int x, int y){
        return mk(sub[x], x) > mk(sub[y], y);
    });

    for(auto viz : V) {
        if(abs(Sa+sub[viz] - Sb) <= abs(Sb+sub[viz] - Sa)) 
            addA(viz), Sa += sub[viz];
        else addB(viz), Sb += sub[viz];
    }

    assert(max(Sa,Sb) <= 2*min(Sa, Sb) + 50);

    for(auto it : v) blocked[it] = 1, sub[it] = 0;

    att_ans(A, B, c);
    att_ans(B, A, c);

    vi aux = B;
    decomp(A);
    decomp(aux);
}

void pre_calc(int x, int pai) {
    sub[x] = 1;

    for(auto [viz, p, id] : edges[x]) {
        if(viz == pai) continue;

        pre_calc(viz, x);
        sub[x] += sub[viz];

        SUB[id] = sub[viz];
        int ot = id%2 ? id+1 : id-1;
        SUB[ot] = n-sub[viz];
    }
}

int32_t main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    cin >> n >> K;

    vi v = {n-1};
    edges.resize(n);

    for(int i = 1, a, b, c; i < n; i++) {
        cin >> a >> b >> c;
        edges[a].pb({b, c, 2*i-1});
        edges[b].pb({a, c, 2*i});
        v.pb(i-1);
    }

    pre_calc(0, 0);
    decomp(v);

    for(int i = 0; i < n; i++)
        cout << ans[i] << "\n";
}
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...