제출 #475755

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
475755 |  | CaroLinda | Janjetina (COCI21_janjetina) | C++14 | 110 / 110 | 1369 ms | 36736 KiB
#include <bits/stdc++.h>

#define debug printf
#define ff first
#define ss second
#define mkt make_tuple
#define all(x) x.begin(),x.end()
#define sz(x) (int)(x.size())
#define ll long long
#define lp(i,a,b) for(int i = a ; i < b ; i++ )
#define pii pair<int,int>
#define mk make_pair
#define pb push_back
#define tiiii tuple<int,int,int, int>

// Capacity of every per-vertex / per-edge array. The slack beyond N leaves room
// for 1-based vertex ids and for the Fenwick tree's shifted (i+1) indices.
const int MAXN = 1e5+10 ;

using namespace std ;

// Accumulated pair count. main() prints ans*2 (presumably each unordered pair
// is counted once here -- confirm against the problem statement).
ll ans ;

// A pending "count partners near me" request stored at vertex y. join() creates
// one when the edge (x, y) of sorted rank r merges x's (smaller) component into
// y's component.
struct Query
{

	// Distance limits: one entry per vertex of the merged-in component, giving
	// the largest distance (in edges) from y that a partner may have.
	std::vector<int> procuro ;
	// other_vert: endpoint of the merging edge on the merged-in side.
	// edge_weight: only paths whose maximum edge rank is <= this value count.
	int other_vert , edge_weight ;

	// Orders queries by their rank limit. `other` is taken by const reference:
	// a by-value parameter copies the whole `procuro` vector on every
	// comparison performed while sorting.
	bool operator < ( const Query &other ) const { return edge_weight < other.edge_weight ; }

} ;

// Number of vertices and the K read from input.
int N, K ;
// dsu: union-find parent; qtd: component size; weight: sorted rank ((i+1)*2)
// assigned to each input edge index.
int dsu[MAXN] , qtd[MAXN] , weight[MAXN] ;
// Forest of the edges already merged by join(); used to walk one component.
vector<int> adj_dsu[MAXN] ;
// Original tree: adj[v] holds (neighbor, input edge index).
vector<pair<int,int> > adj[MAXN] ;
// Input edges as (weight, x, y, input index); sorted in main().
vector< tuple<int,int,int, int> > edges ;
// Queries attached to each vertex by join().
vector<Query> queries[MAXN] ;
// Vertices already removed by the centroid decomposition.
bool marc[MAXN] ;

// ----------------------- UNION FIND -------------------------
// Union-find lookup: returns the root of x and repoints every vertex on the
// walked path directly at that root.
int find(int x)
{
	int root = x ;
	while( dsu[root] != root )
		root = dsu[root] ;

	while( dsu[x] != root )
	{
		const int up = dsu[x] ;
		dsu[x] = root ;
		x = up ;
	}
	return root ;
}
// Raw weight of the edge currently being merged by join().
int w_union_find ;
// Walks the component tree (adj_dsu) from x. For each vertex at distance `depth`
// from x it appends to the newest query of y the partner-distance bound
// w_union_find - K - depth - 1 (clamped to N), skipping negative bounds.
void dfs_union_find(int x, int y , int par, int depth )
{
	const int limit = w_union_find - K - depth - 1 ;
	if( limit >= 0 )
		queries[y].back().procuro.pb( min(limit, N) ) ;

	for( const int nxt : adj_dsu[x] )
	{
		if( nxt != par )
			dfs_union_find(nxt, y, x, depth+1) ;
	}
}
void join(int x, int y, int w, int i)
{
	int a = find(x) ;
	int b = find(y) ;
	w_union_find = w ;

	if( qtd[a] > qtd[b] )
	{
		swap(a,b) ;
		swap(x,y) ;
	}

	queries[y].pb( Query() ) ;
	queries[y].back().edge_weight = i-1 ;
	queries[y].back().other_vert = x ;

	dfs_union_find(x, y, -1, 0 ); 

	adj_dsu[x].pb(y) ;
	adj_dsu[y].pb(x) ;
	dsu[a] = b ;
	qtd[b] += qtd[a] ;

}
// ------------------------------------------------------------

// Size of the component being decomposed (set in decompose(), read by dfs2()).
int q ;
// sub: subtree sizes from dfs1(); bit: Fenwick tree indexed by depth;
// dist: depth of a vertex below the current centroid (set by dfs3()/decompose()).
int sub[MAXN] , bit[MAXN] , dist[MAXN] ;
// pares: (max edge rank on path to centroid, depth) per visited vertex;
// perguntas: (vertex, query index) handles to answer at the current centroid.
vector<pair<int,int> > pares , perguntas ;

// Fenwick point update: adds val at 0-based position i; i == -1 is ignored.
void upd(int i , int val)
{
	if( i == -1 ) return ;
	int pos = i + 1 ;
	while( pos < MAXN )
	{
		bit[pos] += val ;
		pos += pos & -pos ;
	}
}
// Fenwick prefix sum over 0-based positions [0, i]; yields 0 when i < 0.
int qry(int i)
{
	int sum = 0 ;
	for( int pos = i + 1 ; pos > 0 ; pos &= pos - 1 )
		sum += bit[pos] ;
	return sum ;
}

// Fills sub[v] with the size of v's subtree (rooted at x, parent par), looking
// only at vertices not yet removed by the centroid decomposition.
void dfs1(int x, int par )
{
	int total = 1 ;
	for( const auto &edge : adj[x] )
	{
		const int child = edge.ff ;
		if( child == par || marc[child] ) continue ;
		dfs1(child, x) ;
		total += sub[child] ;
	}
	sub[x] = total ;
}
// Starting at x, repeatedly steps into the first child whose subtree holds more
// than q/2 vertices. Returns the vertex where no such child exists (the
// centroid of the current component).
int dfs2(int x, int par )
{
	for(;;)
	{
		int heavy = -1 ;
		for( const auto &edge : adj[x] )
		{
			if( edge.ff == par || marc[edge.ff] ) continue ;
			if( sub[edge.ff] > q/2 )
			{
				heavy = edge.ff ;
				break ;
			}
		}
		if( heavy == -1 ) return x ;
		par = x ;
		x = heavy ;
	}
}

void dfs3(int x, int par , int cur_mx , int cur_depth )
{

	dist[x] = cur_depth ;
	pares.pb( mk(cur_mx, cur_depth) ) ;

	for(int i = 0 ; i < sz(queries[x]) ; i++ )
	{
		if( queries[x][i].other_vert == par || cur_mx > queries[x][i].edge_weight ) continue ;
		perguntas.pb(mk(x,i)) ;
	}

	for(auto e : adj[x] ) 
	{
		if(marc[e.ff] || e.ff == par ) continue ;
		dfs3(e.ff, x, max(cur_mx, weight[e.ss]) , cur_depth+1 ) ;
	}

}

void solve(int ptr1, int ptr2, int s)
{
	while(ptr1 < sz(pares) || ptr2 < sz(perguntas))
	{
		if( ptr2 == sz(perguntas) ) 
		{
			upd(pares[ptr1].ss, 1) ;
			ptr1++ ;
			continue ;
		}

		int x = perguntas[ptr2].ff ;
		int y = perguntas[ptr2].ss ;
		int k = queries[ x ][y].edge_weight ;

		if(ptr1 == sz(pares) || pares[ptr1].ff > k)
		{
			for( auto e : queries[x][y].procuro )
				ans += qry( e - dist[x]) * s ;

			ptr2++ ;
			continue ;
		}
		else
		{
			upd(pares[ptr1].ss , 1 ) ;
			ptr1++ ;
			continue ;
		}

	}
}

// Orders (vertex, query-index) handles by the rank limit of the query they name.
bool cmp(pii p1, pii p2)
{
	const Query &lhs = queries[p1.ff][p1.ss] ;
	const Query &rhs = queries[p2.ff][p2.ss] ;
	return lhs < rhs ;
}

// One centroid-decomposition step on the not-yet-removed component containing x.
// Finds its centroid cn, adds to ans every (query, partner) match whose path
// passes through cn, marks cn as removed, and recurses into the remaining pieces.
void decompose(int x)
{

	dfs1(x,-1) ;
	q = sub[x] ;
	
	int cn = dfs2(x,-1) ;

	marc[cn] = true ;

	pares.clear() ; perguntas.clear() ;
	// The centroid itself: depth 0 and max rank -1, so every query admits it.
	pares.pb( mk( -1, 0 ) ) ;

	for(auto e : adj[cn])
	{
		if(marc[e.ff]) continue ;

		int l_pares=  sz(pares) ;
		int l_perguntas = sz(perguntas) ;

		dfs3(e.ff, cn, weight[e.ss] , 1 ) ;

		// Sort only the entries contributed by this child subtree.
		sort( pares.begin()+l_pares , pares.end() ) ;
		sort( perguntas.begin()+l_perguntas, perguntas.end() , cmp ) ;

		int ptr1 = l_pares , ptr2 = l_perguntas ;

		// s = -1: subtract matches whose query vertex and partner lie in the same
		// child subtree; the global pass below counts them too, so they cancel.
		solve(ptr1, ptr2,-1) ;

		// solve() inserts every pair it was given; undo so the Fenwick tree is zero.
		for(int i = l_pares ; i < sz(pares) ; i++ ) 
			upd( pares[i].ss , -1 ) ;

	}

	// Global pass: all child subtrees plus the centroid entry.
	sort(all(pares)) ;
	sort(all(perguntas),cmp) ;

	solve(0,0,1) ;

	for(int i = 0 ; i < sz(pares) ; i++ ) 
		upd( pares[i].ss , -1 ) ;	

	// Queries stored at the centroid itself. They are not sorted here: join()
	// appends them in increasing edge_weight order because main() calls join()
	// with increasing ranks.
	// NOTE(review): ranks are even and edge_weight starts as rank-1, so the
	// decrement below appears not to change which pares are admitted. It is
	// permanent, but cn is now marked, so dfs3 never visits its queries again.
	perguntas.clear() ;
	for(int i = 0 ; i < sz(queries[cn]) ; i++ ) 
	{
		perguntas.pb(mk(cn,i)) ;
		queries[cn][i].edge_weight-- ;
	}


	// The centroid is at depth 0 relative to itself.
	dist[cn] = 0 ;

	solve(0,0,1)  ;

	for(int i = 0 ; i < sz(pares) ; i++ ) 
		upd( pares[i].ss , -1 ) ;

	// Recurse into every piece left after removing cn.
	for(auto e : adj[cn])
		if(!marc[e.ff]) decompose(e.ff) ;	

}

// Reads N, K and the N-1 weighted tree edges. It processes the edges in
// increasing weight order, giving each the even rank (i+1)*2 from its sorted
// position, to build the per-vertex queries. The queries are then answered with
// centroid decomposition, and 2*ans is printed.
// Returns 1 without printing if the input cannot be read.
int main()
{
	if( scanf("%d %d", &N, &K ) != 2 ) return 1 ;

	for(int i = 1 ; i <= N ; i++ )
	{
		dsu[i] = i ;
		qtd[i] = 1 ;
	}

	for(int i = 0 ; i < N-1 ; i++ )
	{
		int x , y , w ;
		// Stop instead of using uninitialized x/y as array indices.
		if( scanf("%d %d %d", &x, &y, &w ) != 3 ) return 1 ;
		adj[x].pb(mk(y,i)) ;
		adj[y].pb(mk(x,i)) ;
		edges.pb(mkt(w, x, y, i)) ;
	}

	// Equal weights are ordered by (x, y, index); each position gets its own rank.
	sort(all(edges)) ;
	for(int i = 0 ; i < N-1 ; i++ )
	{
		const int w = get<0>(edges[i]) ;
		const int x = get<1>(edges[i]) ;
		const int y = get<2>(edges[i]) ;
		const int edge_rank = (i+1)*2 ;

		join(x, y, w, edge_rank) ;
		weight[get<3>(edges[i])] = edge_rank ;
	}

	decompose(1) ;

	printf("%lld\n" , ans*2LL ) ;
	return 0 ;
}

컴파일 시 표준 에러 (stderr) 메시지

Main.cpp: In function 'int main()':
Main.cpp:231:7: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  231 |  scanf("%d %d", &N, &K ) ;
      |  ~~~~~^~~~~~~~~~~~~~~~~~
Main.cpp:241:8: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  241 |   scanf("%d %d %d", &x, &y, &w ) ;
      |   ~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...