// This submission was migrated from the previous version of oj.uz, which graded on a different machine; resubmitting it may produce a different result.
#include <bits/stdc++.h>
#define pb push_back
#define f first
#define sc second
using namespace std;
typedef long long int ll;
typedef string str;
struct BIT{
int n;
vector<ll> sm;
BIT(int _n){
n = _n;
sm.resize(n);
}
void add(int in, int x){
in++;
while(in <= n) sm[in-1]+=x, in+=in&-in;
}
ll sum(int in){
in++;
ll s = 0;
while(in >= 1) s+=sm[in-1], in-=in&-in;
return s;
}
ll sum(int l, int r){
return sum(r)-(l == 0? 0:sum(l-1));
}
};
// Vertex count and the threshold compared against in cnt(); both are read from stdin in main().
int n, k;
// Adjacency list: v[u] holds (neighbor, edge weight) pairs. main() erases each
// processed decomposition root from its neighbors' lists to split the tree.
vector<vector<pair<int, int>>> v;
// Subtree sizes accumulated by dfs0(); main() resets every entry to 1 at the
// start of each decomposition level.
vector<int> sz;
// Post-order pass: adds each child's subtree size into sz[nd], so that sz[nd]
// ends up as the size of the subtree rooted at nd (parent ss excluded).
// sz is expected to start at 1 for every node of the component.
void dfs0(int nd, int ss){
    for (const auto& edge : v[nd]) {
        const int child = edge.first;
        if (child == ss) continue;
        dfs0(child, nd);
        // The child's size is final once its own traversal has returned.
        sz[nd] += sz[child];
    }
}
// Searches the subtree of nd (parent ss excluded) bottom-up for a centroid of a
// component of size sz0. The first node found wins; -1 means none exists in this subtree.
// A node qualifies when no child subtree and no "upward" remainder
// (sz0 - sz[nd]) exceeds sz0 / 2. Relies on sz[] filled by dfs0 from the same root.
int dfs1(int nd, int ss, int sz0){
    for (const auto& [child, weight] : v[nd]) {
        if (child == ss) continue;
        const int found = dfs1(child, nd, sz0);
        if (found != -1) return found;
    }
    const int half = sz0 / 2;
    if (sz0 - sz[nd] > half) return -1;
    for (const auto& [child, weight] : v[nd]) {
        if (child != ss && sz[child] > half) return -1;
    }
    return nd;
}
// Scratch buffer of (running max edge weight, distance) pairs appended by dfs2()
// and consumed (and reordered) by cnt(); main() clears it before each dfs2() call.
vector<pair<int, int>> sth;
// Counts unordered pairs {i, j}, i != j, of entries in the global `sth`
// (each entry is (mx, dis)) that satisfy
//     max(mx_i, mx_j) - (dis_i + dis_j) >= k.
// Entries are processed in ascending (mx, dis) order. When entry i is reached,
// every entry already inserted into the BIT has mx <= mx_i, so the condition
// reduces to dis_j <= mx_i - dis_i - k, a prefix query over distances.
// Each pair is therefore counted exactly once, at its later element.
// Side effect: sorts `sth` in place.
ll cnt(){
    sort(sth.begin(), sth.end());
    // Largest index the BIT must hold. Distances never exceed sth.size(), so
    // clamping the query limit to this value does not change the count.
    const int maxIdx = (int)sth.size() + 2;
    BIT bit(maxIdx + 1);
    ll cur = 0;
    for (const auto& [mx, dis] : sth) {
        // Computed in 64-bit: mx - dis - k can overflow int (UB) when k or the
        // edge weights are near the int limits.
        const ll lim = (ll)mx - dis - k;
        if (lim >= 0) cur += bit.sum(0, (int)min<ll>(lim, maxIdx));
        bit.add(dis, 1);
    }
    return cur;
}
void dfs2(int nd, int ss, int dis, int mx){
sth.pb({mx, dis});
for(auto [x, w]: v[nd]) if(x != ss) dfs2(x, nd, dis+1, max(mx, w));
}
// Reads a weighted tree: n and k, then n-1 edges "a b w" with 1-based vertices.
// Prints twice the number of unordered vertex pairs whose path P satisfies
//     max edge weight on P - (number of edges on P) >= k.
// Doubling presumably reports ordered pairs.
// NOTE(review): confirm the factor of 2 against the problem statement.
// Uses decomposition processed level by level. `st` holds the roots of the
// current level's components; `st0` collects the next level's roots.
int main(){
ios_base::sync_with_stdio(0);
cin.tie(0);
cin >> n >> k;
v.resize(n);
for(int i = 0; i < n-1; i++){
int a, b, w; cin >> a >> b >> w; a--, b--;
v[a].pb({b, w});
v[b].pb({a, w});
}
// The first root is vertex 0, which need not be a centroid. Every later root
// is a centroid chosen by dfs1().
stack<int> st;
st.push(0);
ll ans = 0;
while(!st.empty()){
stack<int> st0;
// Components on one level are disjoint, so one reset per level is enough.
sz.assign(n, 1);
while(!st.empty()){
int nd = st.top();
st.pop();
// Count qualifying pairs among all nodes of nd's component, treating each
// path as if it went through nd.
sth.clear();
dfs2(nd, -1, 0, 0);
ans+=cnt();
// Remove the pairs whose endpoints both lie behind the same neighbor x.
// Their real path does not pass through nd. Entries start at distance 1 with
// the connecting edge weight, so they match the values used above.
for(auto [x, w]: v[nd]){
sth.clear();
dfs2(x, nd, 1, w);
ans-=cnt();
}
// Detach nd: drop it from each neighbor's adjacency list, making every
// neighbor's subtree its own component. Then queue that component's centroid.
// v[nd] itself is not modified, so iterating it here is safe.
for(auto [x, w]: v[nd]){
v[x].erase(find(v[x].begin(), v[x].end(), make_pair(nd, w)));
dfs0(x, -1);
st0.push(dfs1(x, -1, sz[x]));
}
}
swap(st, st0);
}
ans*=2LL;
cout << ans << "\n";
}
/*
| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |
*/
/*
| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |
*/
/*
| # | Verdict | Execution time | Memory | Grader output |
|---|---------|----------------|--------|---------------|
| Fetching results... | | | | |
*/