제출 #480273

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
480273 | Jarif_Rahman | Janjetina (COCI21_janjetina) | C++17
110 / 110
390 ms | 17544 KiB
#include <bits/stdc++.h>
// Competitive-programming shorthands used throughout this file.
// NOTE: `f` and `sc` are object-like macros, so every identifier spelled
// exactly `f` or `sc` anywhere below this point is rewritten by the
// preprocessor to `first` / `second` (pair member access uses `.f` / `.sc`).
#define pb push_back
#define f first
#define sc second
using namespace std;
// 64-bit integer type used for the answer and for Fenwick-tree sums.
typedef long long int ll;
// String alias; not referenced in the visible code.
typedef string str;

/// Fenwick (binary indexed) tree over 0-based indices [0, n) holding
/// long long sums. Point increments and prefix/range queries each walk
/// O(log n) slots via the lowest-set-bit step `i & -i`.
struct BIT{
    int n;
    std::vector<long long> sm;

    /// Creates a tree with _n zero-initialised slots (valid indices 0.._n-1).
    /// `explicit` so an int never silently converts into a BIT.
    explicit BIT(int _n) : n(_n), sm(_n) {}

    /// Adds x to the element at 0-based index `in`.
    /// `in` must lie in [0, n). A negative index would make the internal
    /// index 0, where `i & -i == 0`, so the loop would never advance; the
    /// guard keeps that from happening even when assert is compiled out.
    void add(int in, long long x){
        assert(0 <= in && in < n);
        if(in < 0 || in >= n) return;
        for(int i = in + 1; i <= n; i += i & -i) sm[i-1] += x;
    }

    /// Returns the sum of elements at indices [0, in].
    /// in < 0 yields 0; in >= n is clamped to the last index, which gives
    /// the whole-array sum instead of reading past the end of `sm`.
    long long sum(int in){
        if(in >= n) in = n - 1;
        long long s = 0;
        for(int i = in + 1; i >= 1; i -= i & -i) s += sm[i-1];
        return s;
    }

    /// Returns the sum over the closed index range [l, r].
    /// An empty range (l > r) yields 0 rather than a negative difference.
    long long sum(int l, int r){
        if(l > r) return 0;
        return sum(r) - (l == 0 ? 0 : sum(l-1));
    }
};

// Node count read from input.
int n;
// Constant read from input; cnt() subtracts it in its pair condition.
int k;
// Adjacency list: v[u] holds (neighbor, edge weight) pairs.
std::vector<std::vector<std::pair<int, int>>> v;
// Subtree sizes, filled by dfs0 and read by dfs1.
std::vector<int> sz;

// Accumulates subtree sizes into the global `sz` for the tree hanging from
// nd, never stepping back into the parent ss. It only adds children's sizes,
// so it assumes sz[] was pre-filled with 1 for these nodes (main does
// sz.assign(n, 1) at the start of every level). On return, sz[nd] is the
// size of nd's subtree.
void dfs0(int nd, int ss){
    for(const auto& edge : v[nd]){
        const int child = edge.first;
        if(child == ss) continue;
        dfs0(child, nd);
        sz[nd] += sz[child];
    }
}

// Searches the subtree rooted at nd (parent ss) in post-order for a node
// whose every adjacent part is at most sz0/2 nodes, where sz0 is the size of
// the whole component. It returns the first such node found, or -1 if none
// in this subtree qualifies. Relies on `sz` having been filled by dfs0 from
// the same root.
int dfs1(int nd, int ss, int sz0){
    for(const auto& edge : v[nd]){
        if(edge.first == ss) continue;
        const int found = dfs1(edge.first, nd, sz0);
        if(found != -1) return found;
    }

    const int half = sz0 / 2;
    // Part of the component lying "above" nd, on the parent's side.
    if(sz0 - sz[nd] > half) return -1;
    for(const auto& edge : v[nd]){
        if(edge.first != ss && sz[edge.first] > half) return -1;
    }
    return nd;
}

// Scratch list of (max edge weight on the path, depth) entries.
// dfs2 fills it and cnt() consumes it (and re-sorts it).
std::vector<std::pair<int, int>> sth;

// Counts unordered pairs of distinct entries a, b of the global `sth`
// satisfying dist_a + dist_b <= max(mx_a, mx_b) - k. Each entry is
// (mx = max edge weight on the path, dist = depth).
// Entries are visited in ascending mx order, so the current entry supplies
// the pair's maximum. A BIT keyed by depth counts the earlier partners whose
// depth fits the remaining budget.
// Side effect: leaves `sth` sorted in descending order.
ll cnt(){
    sort(sth.begin(), sth.end(), greater<pair<int, int>>());
    const int total = (int)sth.size();
    BIT fen(total + 3);
    ll pairs = 0;
    // Walking the descending array backwards visits entries in ascending order.
    for(auto it = sth.rbegin(); it != sth.rend(); ++it){
        const int mxw = it->first;
        const int dist = it->second;
        int budget = mxw - dist - k;
        // Keep the query inside the BIT's index range.
        if(budget > total + 2) budget = total + 2;
        if(budget >= 0) pairs += fen.sum(0, budget);
        fen.add(dist, 1);
    }
    return pairs;
}

// Pre-order walk from nd that never steps back into ss. For every node
// reached it appends (running max edge weight, depth) to the global `sth`;
// `dis` and `mx` are the values for nd itself.
void dfs2(int nd, int ss, int dis, int mx){
    sth.push_back(make_pair(mx, dis));
    for(const auto& edge : v[nd]){
        const int nxt = edge.first;
        if(nxt == ss) continue;
        const int weight = edge.second;
        dfs2(nxt, nd, dis + 1, weight > mx ? weight : mx);
    }
}

// Reads n, k and n-1 weighted edges "a b w" (1-indexed), then counts ordered
// pairs of distinct nodes (x, y) whose path satisfies
// (max edge weight) - (edge count) >= k.
// Method: level-by-level centroid decomposition. For each centroid, cnt()
// over every node reachable from it counts the qualifying pairs. Pairs that
// lie inside a single neighbor's subtree are then subtracted; their real path
// does not pass through the centroid, and they are counted at a deeper level.
// cnt() counts unordered pairs, hence the final doubling.
int main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cin >> n >> k;
    v.resize(n);
    for(int i = 0; i < n-1; i++){
        int a, b, w; cin >> a >> b >> w; a--, b--;
        v[a].pb({b, w});
        v[b].pb({a, w});
    }

    // Work list for the current decomposition level. The first "centroid" is
    // simply node 0. Correctness does not depend on that choice; only the
    // first split may be unbalanced.
    stack<int> st;
    st.push(0);

    ll ans = 0;

    while(!st.empty()){
        // Centroids of the next level's components.
        stack<int> st0;
        // Components within one level are disjoint, so one reset per level
        // suffices for dfs0's accumulation.
        sz.assign(n, 1);
        while(!st.empty()){
            int nd = st.top();
            st.pop();
            // All pairs measured through nd (nd itself contributes (0, 0)).
            sth.clear();
            dfs2(nd, -1, 0, 0);
            ans+=cnt();
            // Remove same-subtree pairs, measured with the same (mx, dis)
            // values they had relative to nd.
            for(auto [x, w]: v[nd]){
                sth.clear();
                dfs2(x, nd, 1, w);
                ans-=cnt();
            }

            // Detach nd: drop each neighbor's back-edge to it, then find the
            // centroid of that neighbor's now-separate component.
            // v[nd] is left intact; nd is never visited again.
            for(auto [x, w]: v[nd]){
                v[x].erase(find(v[x].begin(), v[x].end(), make_pair(nd, w)));
                dfs0(x, -1);
                st0.push(dfs1(x, -1, sz[x]));
            }
        }
        swap(st, st0);
    }

    // cnt() counted each unordered pair once; the answer wants ordered pairs.
    ans*=2LL;

    cout << ans << "\n";
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...