Submission #480273

# | Time | Username | Problem | Language | Result | Execution time | Memory
480273 | | Jarif_Rahman | Janjetina (COCI21_janjetina) | C++17 | 110 / 110 | 390 ms | 17544 KiB
// Solution for "Janjetina" (COCI21_janjetina), per the submission header.
//
// Input : n k, followed by n-1 lines "a b w" describing the weighted edges of
//         a tree on vertices 1..n.
// Output: the number of ordered vertex pairs (u, v), u != v, for which
//         (maximum edge weight on the path u..v) - (number of edges on the
//         path) >= k.
//
// Method: repeatedly pick a root inside every remaining component, count the
// qualifying pairs whose path passes through that root, then detach the root
// and continue inside each of the pieces that are left.
#include <algorithm>
#include <iostream>
#include <stack>
#include <utility>
#include <vector>

using ll = long long;

// Fenwick tree (binary indexed tree) over the 0-based positions [0, n).
struct BIT {
    int n;
    std::vector<ll> sm;

    BIT(int _n) : n(_n), sm(static_cast<std::size_t>(_n)) {}

    // Adds x to position `in`.
    void add(int in, int x) {
        for (++in; in <= n; in += in & -in) sm[in - 1] += x;
    }

    // Sum of the positions [0, in].
    ll sum(int in) const {
        ll s = 0;
        for (++in; in >= 1; in -= in & -in) s += sm[in - 1];
        return s;
    }

    // Sum of the positions [l, r].
    ll sum(int l, int r) const { return sum(r) - (l == 0 ? 0 : sum(l - 1)); }
};

int n, k;
std::vector<std::vector<std::pair<int, int>>> v;  // v[a] = list of (neighbour, edge weight)
std::vector<int> sz;                              // subtree sizes, filled by dfs0
std::vector<std::pair<int, int>> sth;             // (max edge weight, edge count) per vertex, filled by dfs2

// Accumulates subtree sizes below `nd` (parent `ss`) into sz.
// Every sz entry of the component must be 1 on entry.
void dfs0(int nd, int ss) {
    for (const auto& e : v[nd])
        if (e.first != ss) dfs0(e.first, nd);
    for (const auto& e : v[nd])
        if (e.first != ss) sz[nd] += sz[e.first];
}

// Searches the subtree of `nd` (parent `ss`) for a centroid of a component
// holding sz0 vertices: a vertex whose child subtrees and whose "upward"
// remainder all contain at most sz0/2 vertices. Returns -1 if this subtree
// contains no such vertex. Relies on sz as computed by dfs0 from the same root.
int dfs1(int nd, int ss, int sz0) {
    for (const auto& e : v[nd]) {
        if (e.first == ss) continue;
        const int rt = dfs1(e.first, nd, sz0);
        if (rt != -1) return rt;
    }
    bool balanced = true;
    for (const auto& e : v[nd])
        if (e.first != ss && sz[e.first] > sz0 / 2) balanced = false;
    if (sz0 - sz[nd] > sz0 / 2) balanced = false;
    return balanced ? nd : -1;
}

// Counts the unordered pairs {i, j}, i != j, of entries in sth that satisfy
//   max(mx_i, mx_j) - (dis_i + dis_j) >= k.
// Reorders sth as a side effect.
ll cnt() {
    // Ascending order by max weight: when an entry is visited, every entry
    // already inserted has a max weight that is no larger, so the visited
    // entry's own mx is the maximum of the pair.
    std::sort(sth.begin(), sth.end());
    const int m = static_cast<int>(sth.size());
    BIT bit(m + 3);  // positions are edge counts; dfs2 never produces one above m
    ll cur = 0;
    for (const auto& [mx, dis] : sth) {
        // Largest partner edge count that still satisfies the inequality,
        // computed in 64 bits and clamped to the tree's index range.
        const ll raw = static_cast<ll>(mx) - dis - k;
        const int lim = static_cast<int>(std::min<ll>(raw, m + 2));
        if (lim >= 0) cur += bit.sum(0, lim);
        bit.add(dis, 1);
    }
    return cur;
}

// Appends (mx, dis) for `nd` and for every vertex below it (parent `ss`) to
// sth, where dis is the number of edges walked so far and mx is the largest
// edge weight seen so far. Recursion depth grows with the depth of the tree.
void dfs2(int nd, int ss, int dis, int mx) {
    sth.push_back({mx, dis});
    for (const auto& [x, w] : v[nd])
        if (x != ss) dfs2(x, nd, dis + 1, std::max(mx, w));
}

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    if (!(std::cin >> n >> k) || n < 0) {
        std::cerr << "invalid input\n";
        return 1;
    }
    if (n == 0) {
        // No vertices, hence no pairs; there is no vertex 0 to start from.
        std::cout << 0 << "\n";
        return 0;
    }

    v.assign(n, {});
    for (int i = 0; i < n - 1; i++) {
        int a = 0, b = 0, w = 0;
        if (!(std::cin >> a >> b >> w) || a < 1 || a > n || b < 1 || b > n) {
            std::cerr << "invalid input\n";
            return 1;
        }
        a--, b--;
        v[a].push_back({b, w});
        v[b].push_back({a, w});
    }

    // `st` holds the roots of the current level, `st0` collects the roots of
    // the next one. The very first root is vertex 0 rather than a centroid;
    // the later ones are the centroids returned by dfs1.
    std::stack<int> st;
    st.push(0);
    ll ans = 0;
    while (!st.empty()) {
        std::stack<int> st0;
        sz.assign(n, 1);  // components of one level are disjoint, so one reset per level suffices
        while (!st.empty()) {
            const int nd = st.top();
            st.pop();

            // All pairs inside nd's component, measured via nd ...
            sth.clear();
            dfs2(nd, -1, 0, 0);
            ans += cnt();

            // ... minus the pairs lying in the same neighbour subtree, whose
            // real path does not pass through nd.
            for (const auto& [x, w] : v[nd]) {
                sth.clear();
                dfs2(x, nd, 1, w);
                ans -= cnt();
            }

            // Detach nd from each neighbour and queue the centroid of the
            // piece that neighbour now belongs to.
            for (const auto& [x, w] : v[nd]) {
                auto& adj = v[x];
                adj.erase(std::find(adj.begin(), adj.end(), std::make_pair(nd, w)));
                dfs0(x, -1);
                st0.push(dfs1(x, -1, sz[x]));
            }
        }
        std::swap(st, st0);
    }

    ans *= 2LL;  // every unordered pair stands for two ordered pairs
    std::cout << ans << "\n";
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...