This submission was migrated from the previous version of oj.uz, which used a different machine for grading. This submission may get a different result if resubmitted.
#include <bits/stdc++.h>
// Competitive-programming shorthand macros. They are textual substitutions
// used throughout the file, so they are left exactly as written.
// `debug` aliases printf (not referenced in the code shown).
#define debug printf
// Pair member / tuple shorthands.
#define ff first
#define ss second
#define mkt make_tuple
// Container helpers: full iterator range, and size as a signed int.
#define all(x) x.begin(),x.end()
#define sz(x) (int)(x.size())
#define ll long long
// Index loop over [a, b); not referenced in the code shown.
#define lp(i,a,b) for(int i = a ; i < b ; i++ )
#define pii pair<int,int>
#define mk make_pair
#define pb push_back
// 4-int tuple alias; the code shown spells the type out instead.
#define tiiii tuple<int,int,int, int>
// Capacity of every per-vertex / per-edge global array. main() writes
// indices 1..N, and the Fenwick tree is read at index N+1, so N must stay
// below MAXN-1.
const int MAXN = 1e5+10 ;
using namespace std ;
// Running pair count accumulated by solve(); main() prints ans*2.
ll ans ;
// One pending counting query. join() creates it when the tree edge
// (other_vert, owner) is merged in Kruskal order, and stores it in
// queries[owner].
struct Query
{
    // Distance budgets, one per vertex of the smaller component. For each
    // budget, solve() counts vertices within that many steps of the owner.
    vector<int> procuro ;
    // other_vert  : endpoint of the merging edge on the smaller side.
    // edge_weight : (rank of the merging edge) - 1. A vertex is only counted
    //               if every edge on its path has rank <= edge_weight, i.e.
    //               it was already connected to the owner before this merge.
    int other_vert , edge_weight ;
    // Orders queries by their rank threshold. The operand is taken by const
    // reference: the former by-value parameter copied the whole `procuro`
    // vector on every comparison performed while sorting.
    bool operator < ( const Query& other ) const { return edge_weight < other.edge_weight ; }
} ;
// Vertex count and the threshold K, both read in main().
int N ;
int K ;
// Union-find parent pointers.
int dsu[MAXN] ;
// Component sizes, valid at roots; used for union by size.
int qtd[MAXN] ;
// Kruskal rank of each edge id: 2, 4, 6, ... in sorted-weight order.
int weight[MAXN] ;
// Forest formed by the edges merged so far.
vector<int> adj_dsu[MAXN] ;
// Original tree: (neighbour, edge id) pairs.
vector<pair<int,int> > adj[MAXN] ;
// Input edges as (weight, x, y, edge id).
vector< tuple<int,int,int, int> > edges ;
// Queries anchored at each vertex, appended in increasing edge_weight order.
vector<Query> queries[MAXN] ;
// Vertices already used as centroids.
bool marc[MAXN] ;
// ----------------------- UNION FIND -------------------------
// Returns the union-find root of x, compressing the path along the way.
int find(int x)
{
    if (dsu[x] != x)
        dsu[x] = find(dsu[x]) ;
    return dsu[x] ;
}
int w_union_find ;
void dfs_union_find(int x, int y , int par, int depth )
{
if(w_union_find-K-depth-1 >= 0)
queries[y].back().procuro.pb( min(w_union_find-K-depth-1, N) ) ;
for(auto e : adj_dsu[x] )
{
if(e == par ) continue ;
dfs_union_find(e, y, x, depth+1) ;
}
}
void join(int x, int y, int w, int i)
{
int a = find(x) ;
int b = find(y) ;
w_union_find = w ;
if( qtd[a] > qtd[b] )
{
swap(a,b) ;
swap(x,y) ;
}
queries[y].pb( Query() ) ;
queries[y].back().edge_weight = i-1 ;
queries[y].back().other_vert = x ;
dfs_union_find(x, y, -1, 0 );
adj_dsu[x].pb(y) ;
adj_dsu[y].pb(x) ;
dsu[a] = b ;
qtd[b] += qtd[a] ;
}
// ------------------------------------------------------------
// Size of the component currently being decomposed (read by dfs2).
int q ;
// Subtree sizes computed by dfs1.
int sub[MAXN] ;
// Fenwick tree indexed by depth + 1.
int bit[MAXN] ;
// Depth of each vertex from the current centroid (set by dfs3).
int dist[MAXN] ;
// pares     : (max edge rank from centroid, depth) per visited vertex.
// perguntas : (vertex, index into queries[vertex]) for eligible queries.
vector<pair<int,int> > pares ;
vector<pair<int,int> > perguntas ;
// Adds val at position i of the Fenwick tree (stored 1-based at i+1).
// i == -1 is ignored.
void upd(int i , int val)
{
    if (i != -1)
        for (int pos = i + 1 ; pos < MAXN ; pos += pos & -pos)
            bit[pos] += val ;
}
// Prefix sum of Fenwick positions 0..i; returns 0 for i < 0.
int qry(int i)
{
    int acc = 0 ;
    int pos = i + 1 ;
    while (pos > 0)
    {
        acc += bit[pos] ;
        pos &= pos - 1 ;  // drop the lowest set bit
    }
    return acc ;
}
// Fills sub[] with subtree sizes of the unmarked component rooted at x.
void dfs1(int x, int par )
{
    sub[x] = 1 ;
    for (const auto& nb : adj[x])
    {
        if (nb.ff == par || marc[nb.ff]) continue ;
        dfs1(nb.ff, x) ;
        sub[x] += sub[nb.ff] ;
    }
}
// Finds the centroid of the current component. Starting at x, it repeatedly
// steps into the first unmarked child whose subtree exceeds q/2, and stops
// when no such child exists.
int dfs2(int x, int par )
{
    int cur = x ;
    int from = par ;
    for (;;)
    {
        bool found = false ;
        int heavy = 0 ;
        for (const auto& e : adj[cur])
        {
            if (e.ff == from || marc[e.ff]) continue ;
            if (sub[e.ff] > q/2)
            {
                heavy = e.ff ;
                found = true ;
                break ;
            }
        }
        if (!found) return cur ;
        from = cur ;
        cur = heavy ;
    }
}
// Visits x's part of the current component, recording for each vertex its
// depth and the maximum edge rank on the path from the centroid (pares).
// A query anchored at x is kept (perguntas) only if its rank threshold is
// at least that maximum and its other_vert is not the DFS parent.
void dfs3(int x, int par , int cur_mx , int cur_depth )
{
    dist[x] = cur_depth ;
    pares.pb( mk(cur_mx, cur_depth) ) ;
    const int nq = sz(queries[x]) ;
    for (int idx = 0 ; idx < nq ; ++idx)
    {
        const auto& cand = queries[x][idx] ;
        const bool usable = cand.other_vert != par && cand.edge_weight >= cur_mx ;
        if (usable)
            perguntas.pb(mk(x, idx)) ;
    }
    for (const auto& nb : adj[x])
        if (!marc[nb.ff] && nb.ff != par)
            dfs3(nb.ff, x, max(cur_mx, weight[nb.ss]), cur_depth + 1) ;
}
// Offline sweep over pares (sorted by max rank) and perguntas (sorted by
// rank threshold), starting at ptr1 / ptr2. Before each query is answered,
// every pair whose max rank is <= its threshold has been inserted into the
// Fenwick tree by depth. Each budget e of the query then adds
// s * (#inserted depths <= e - dist[x]) to ans.
void solve(int ptr1, int ptr2, int s)
{
    const int np = sz(pares) ;
    const int nq = sz(perguntas) ;
    while (ptr1 < np || ptr2 < nq)
    {
        bool take_pair ;
        if (ptr2 == nq)
            take_pair = true ;
        else
        {
            const int k = queries[ perguntas[ptr2].ff ][ perguntas[ptr2].ss ].edge_weight ;
            take_pair = ptr1 < np && pares[ptr1].ff <= k ;
        }
        if (take_pair)
        {
            upd(pares[ptr1].ss, 1) ;
            ++ptr1 ;
        }
        else
        {
            const int x = perguntas[ptr2].ff ;
            const int y = perguntas[ptr2].ss ;
            for (auto e : queries[x][y].procuro)
                ans += qry( e - dist[x]) * s ;
            ++ptr2 ;
        }
    }
}
// Sort key for perguntas: orders (vertex, query index) handles by the rank
// threshold of the referenced query. Both queries are bound by const
// reference and the key is compared directly. Going through
// Query::operator< could copy a whole Query, including its possibly large
// `procuro` vector, on every comparison.
bool cmp(pii p1, pii p2)
{
    const auto& a = queries[p1.ff][p1.ss] ;
    const auto& b = queries[p2.ff][p2.ss] ;
    return a.edge_weight < b.edge_weight ;
}
// Centroid decomposition of the unmarked component containing x. For every
// query anchored in this component, it counts the partners whose path to
// the query vertex passes through the centroid cn, then recurses into the
// remaining pieces.
void decompose(int x)
{
// Locate and remove the centroid of this component.
dfs1(x,-1) ;
q = sub[x] ;
int cn = dfs2(x,-1) ;
marc[cn] = true ;
pares.clear() ; perguntas.clear() ;
// The centroid itself: depth 0, and a max rank (-1) below every threshold.
pares.pb( mk( -1, 0 ) ) ;
// Per-subtree pass with s = -1: subtract pairs whose vertices lie in the
// same subtree. Those pairs are counted by the global pass below, but their
// real path does not go through cn.
for(auto e : adj[cn])
{
if(marc[e.ff]) continue ;
int l_pares= sz(pares) ;
int l_perguntas = sz(perguntas) ;
dfs3(e.ff, cn, weight[e.ss] , 1 ) ;
// Sort only this subtree's freshly appended slice.
sort( pares.begin()+l_pares , pares.end() ) ;
sort( perguntas.begin()+l_perguntas, perguntas.end() , cmp ) ;
int ptr1 = l_pares , ptr2 = l_perguntas ;
solve(ptr1, ptr2,-1) ;
// Undo this slice's Fenwick insertions.
for(int i = l_pares ; i < sz(pares) ; i++ )
upd( pares[i].ss , -1 ) ;
}
// Global pass with s = +1 over every vertex of the component (centroid included).
sort(all(pares)) ;
sort(all(perguntas),cmp) ;
solve(0,0,1) ;
for(int i = 0 ; i < sz(pares) ; i++ )
upd( pares[i].ss , -1 ) ;
// Queries anchored at the centroid itself, answered against all pares.
// They need no sort: join() appends queries[v] in increasing edge_weight.
perguntas.clear() ;
for(int i = 0 ; i < sz(queries[cn]) ; i++ )
{
perguntas.pb(mk(cn,i)) ;
// NOTE(review): edge_weight is odd (rank - 1), and pares ranks are even
// or -1, so this decrement does not change which pares satisfy ff <= k.
// It looks like a no-op. It permanently edits queries[cn], which is safe
// only because cn is now marked and is never revisited by dfs3.
queries[cn][i].edge_weight-- ;
}
// solve() subtracts dist[x] from each budget; the centroid is at distance 0.
dist[cn] = 0 ;
solve(0,0,1) ;
for(int i = 0 ; i < sz(pares) ; i++ )
upd( pares[i].ss , -1 ) ;
// Recurse into each remaining piece.
for(auto e : adj[cn])
if(!marc[e.ff]) decompose(e.ff) ;
}
int main()
{
scanf("%d %d", &N, &K ) ;
for(int i = 1 ; i <= N ; i++ )
{
dsu[i] = i ;
qtd[i] = 1 ;
}
for(int i = 0, x , y , w ; i < N-1 ; i++ )
{
scanf("%d %d %d", &x, &y, &w ) ;
adj[x].pb(mk(y,i)) ;
adj[y].pb(mk(x,i) );
edges.pb(mkt(w, x, y,i)) ;
}
sort(all(edges)) ;
for(int i = 0 , x , y , w ; i < N-1; i++ )
{
w = get<0>(edges[i]) ;
x = get<1>(edges[i]) ;
y = get<2>(edges[i]) ;
join(x,y,w, (i+1)*2 ) ;
weight[get<3>(edges[i])] = (i+1)*2 ;
}
decompose(1) ;
printf("%lld\n" , ans*2LL ) ;
}
Compilation message (stderr)
Main.cpp: In function 'int main()':
Main.cpp:231:7: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
231 | scanf("%d %d", &N, &K ) ;
| ~~~~~^~~~~~~~~~~~~~~~~~
Main.cpp:241:8: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
241 | scanf("%d %d %d", &x, &y, &w ) ;
| ~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |