Submission #475755

# | Time | Username | Problem | Language | Result | Execution time | Memory
475755 | | CaroLinda | Janjetina (COCI21_janjetina) | C++14 | 110 / 110 | 1369 ms | 36736 KiB
// Janjetina (COCI21_janjetina).
//
// Counts pairs of distinct vertices (u, v) of a weighted tree whose path
// satisfies   path_length(u, v) <= max_edge_weight_on_path(u, v) - K.
// Each unordered pair is discovered exactly once; the printed answer is
// doubled, i.e. ordered pairs are reported.
//
// Outline:
//  1. Edges are processed in increasing weight (Kruskal order) with a
//     union-by-size DSU. When edge e = (x, y) of weight w joins two
//     components, e is the maximum edge on every path crossing it. For each
//     vertex u on the smaller side (depth d from x), a valid partner v on the
//     other side must satisfy dist(y, v) <= w - K - d - 1. Those thresholds
//     are stored as a Query attached to y.
//  2. Edges are relabelled with distinct even ranks 2, 4, 6, ... A Query
//     created by the edge of rank r stores edge_weight = r - 1. "v was in y's
//     component at join time" is then exactly "every edge on path(y, v) has
//     rank <= edge_weight".
//  3. Centroid decomposition answers all queries offline. For each centroid
//     it collects (max rank to centroid, depth) for every vertex, sweeps
//     queries by edge_weight, and counts depths with a Fenwick tree. Pairs
//     inside the same child subtree are removed by inclusion-exclusion.
//
// NOTE(review): dfs_union_find / dfs1 / dfs2 / dfs3 are recursive; on a
// path-shaped tree the depth reaches ~N, which relies on a default-size stack.
#include <bits/stdc++.h>

#define debug printf
#define ff first
#define ss second
#define mkt make_tuple
#define all(x) x.begin(),x.end()
#define sz(x) (int)(x.size())
#define ll long long
#define lp(i,a,b) for(int i = a ; i < b ; i++ )
#define pii pair<int,int>
#define mk make_pair
#define pb push_back
#define tiiii tuple<int,int,int, int>

const int MAXN = 1e5+10 ;

using namespace std ;

// Number of good unordered pairs found so far.
ll ans ;

// Offline question created when an edge joins two DSU components.
// It is stored at the endpoint y lying in the larger component.
struct Query
{
    // For each vertex u on the smaller side: the largest allowed distance
    // from y to a partner v on y's side (clamped to N).
    vector<int> procuro ;
    // Endpoint of the joining edge on the smaller side.
    int other_vert ;
    // (rank of the joining edge) - 1. Only paths whose edges all have
    // rank <= this value stay inside y's component.
    int edge_weight ;

    // Take the argument by const reference: the previous by-value parameter
    // copied the whole `procuro` vector on every comparison made by sort().
    bool operator < ( const Query &other ) const { return edge_weight < other.edge_weight ; }
} ;

int N, K ;
int dsu[MAXN] , qtd[MAXN] , weight[MAXN] ;  // DSU parent, component size, rank of edge by input index
vector<int> adj_dsu[MAXN] ;                  // forest of edges joined so far
vector<pair<int,int> > adj[MAXN] ;           // full tree: (neighbor, edge index)
vector< tuple<int,int,int, int> > edges ;    // (w, x, y, input index)
vector<Query> queries[MAXN] ;
bool marc[MAXN] ;                            // vertex already used as a centroid

// ----------------------- UNION FIND -------------------------

// DSU find with path compression.
int find(int x) { return dsu[x] = (x == dsu[x]) ? x : find(dsu[x]) ; }

// Original weight of the edge currently being joined.
int w_union_find ;

// Walks the smaller component from the joining endpoint x (depth = distance
// from x). For every vertex it records how far from y a partner may be, in
// the newest Query of y. Vertices whose budget is already negative add nothing.
void dfs_union_find(int x, int y , int par, int depth )
{
    if(w_union_find-K-depth-1 >= 0)
        queries[y].back().procuro.pb( min(w_union_find-K-depth-1, N) ) ;
    for(auto e : adj_dsu[x] )
    {
        if(e == par ) continue ;
        dfs_union_find(e, y, x, depth+1) ;
    }
}

// Joins the components of x and y through an edge of original weight w and
// rank i. It first builds the Query for the pairs this edge creates, then
// merges by size.
void join(int x, int y, int w, int i)
{
    int a = find(x) ;
    int b = find(y) ;
    w_union_find = w ;
    // Make x the endpoint on the smaller side so the walk below is small-to-large.
    if( qtd[a] > qtd[b] ) { swap(a,b) ; swap(x,y) ; }
    queries[y].pb( Query() ) ;
    queries[y].back().edge_weight = i-1 ;
    queries[y].back().other_vert = x ;
    dfs_union_find(x, y, -1, 0 );
    adj_dsu[x].pb(y) ;
    adj_dsu[y].pb(x) ;
    dsu[a] = b ;
    qtd[b] += qtd[a] ;
}

// ------------------------------------------------------------

int q ;                                  // size of the component being decomposed
int sub[MAXN] , bit[MAXN] , dist[MAXN] ; // subtree sizes, Fenwick tree over depth, depth from centroid
vector<pair<int,int> > pares ;           // (max edge rank to centroid, depth), one per vertex
vector<pair<int,int> > perguntas ;       // (vertex, index into queries[vertex])

// Fenwick point update at depth i (indices shifted by one); -1 is ignored.
void upd(int i , int val) { if(i == -1 ) return ; for(++i; i < MAXN ; i += i &-i ) bit[i] += val; }

// Fenwick prefix count of depths in [0, i]; any negative i yields 0.
int qry(int i) { int tot = 0 ; for(++i; i > 0 ; i -= i &-i ) tot += bit[i] ; return tot ; }

// Computes subtree sizes of the current (unmarked) component rooted at x.
void dfs1(int x, int par )
{
    sub[x] = 1 ;
    for(auto e : adj[x])
        if(e.ff != par && !marc[e.ff])
        {
            dfs1(e.ff,x) ;
            sub[x] += sub[e.ff] ;
        }
}

// Descends towards any child holding more than half the component; returns the centroid.
int dfs2(int x, int par )
{
    for(auto e : adj[x] )
    {
        if(e.ff == par || marc[e.ff]) continue ;
        if( sub[e.ff] > q/2 ) return dfs2(e.ff, x ) ;
    }
    return x ;
}

// Collects (max rank, depth) for every vertex of one centroid subtree. It also
// collects the queries whose own path to the centroid stays below their
// edge_weight threshold.
void dfs3(int x, int par , int cur_mx , int cur_depth )
{
    dist[x] = cur_depth ;
    pares.pb( mk(cur_mx, cur_depth) ) ;
    for(int i = 0 ; i < sz(queries[x]) ; i++ )
    {
        if( queries[x][i].other_vert == par || cur_mx > queries[x][i].edge_weight ) continue ;
        perguntas.pb(mk(x,i)) ;
    }
    for(auto e : adj[x] )
    {
        if(marc[e.ff] || e.ff == par ) continue ;
        dfs3(e.ff, x, max(cur_mx, weight[e.ss]) , cur_depth+1 ) ;
    }
}

// Two-pointer sweep over pares[ptr1..] (sorted by max rank) and
// perguntas[ptr2..] (sorted by edge_weight). A vertex's depth enters the
// Fenwick tree once its max rank <= the current query's edge_weight. Each
// query then adds s * (#depths <= threshold - dist[query vertex]).
// On return every pares[ptr1..] entry has been inserted; the caller removes them.
void solve(int ptr1, int ptr2, int s)
{
    while(ptr1 < sz(pares) || ptr2 < sz(perguntas))
    {
        if( ptr2 == sz(perguntas) )
        {
            upd(pares[ptr1].ss, 1) ;
            ptr1++ ;
            continue ;
        }
        int x = perguntas[ptr2].ff ;
        int y = perguntas[ptr2].ss ;
        int k = queries[ x ][y].edge_weight ;
        if(ptr1 == sz(pares) || pares[ptr1].ff > k)
        {
            for( auto e : queries[x][y].procuro )
                ans += qry( e - dist[x]) * s ;
            ptr2++ ;
        }
        else
        {
            upd(pares[ptr1].ss , 1 ) ;
            ptr1++ ;
        }
    }
}

// Orders (vertex, query index) handles by the referenced query's edge_weight.
bool cmp(pii p1, pii p2) { return queries[p1.ff][p1.ss] < queries[p2.ff][p2.ss] ; }

// Centroid decomposition of the component containing x. It answers, for
// every query, the partners whose path passes through this centroid.
void decompose(int x)
{
    dfs1(x,-1) ;
    q = sub[x] ;
    int cn = dfs2(x,-1) ;
    marc[cn] = true ;
    pares.clear() ;
    perguntas.clear() ;
    pares.pb( mk( -1, 0 ) ) ;  // the centroid itself: no edges, depth 0

    // Subtract pairs whose two endpoints lie in the same child subtree.
    for(auto e : adj[cn])
    {
        if(marc[e.ff]) continue ;
        int l_pares= sz(pares) ;
        int l_perguntas = sz(perguntas) ;
        dfs3(e.ff, cn, weight[e.ss] , 1 ) ;
        sort( pares.begin()+l_pares , pares.end() ) ;
        sort( perguntas.begin()+l_perguntas, perguntas.end() , cmp ) ;
        int ptr1 = l_pares , ptr2 = l_perguntas ;
        solve(ptr1, ptr2,-1) ;
        for(int i = l_pares ; i < sz(pares) ; i++ ) upd( pares[i].ss , -1 ) ;
    }

    // Add all pairs through the centroid (same-subtree ones were subtracted above).
    sort(all(pares)) ;
    sort(all(perguntas),cmp) ;
    solve(0,0,1) ;
    for(int i = 0 ; i < sz(pares) ; i++ ) upd( pares[i].ss , -1 ) ;

    // Queries stored at the centroid itself (dfs3 never visits cn). They are
    // pushed in creation order, which is already increasing edge_weight.
    perguntas.clear() ;
    for(int i = 0 ; i < sz(queries[cn]) ; i++ )
    {
        perguntas.pb(mk(cn,i)) ;
        queries[cn][i].edge_weight-- ;
    }
    dist[cn] = 0 ;
    solve(0,0,1) ;
    for(int i = 0 ; i < sz(pares) ; i++ ) upd( pares[i].ss , -1 ) ;

    for(auto e : adj[cn]) if(!marc[e.ff]) decompose(e.ff) ;
}

int main()
{
    // Check input reads instead of silently continuing with garbage values.
    if( scanf("%d %d", &N, &K ) != 2 ) return 1 ;
    for(int i = 1 ; i <= N ; i++ ) { dsu[i] = i ; qtd[i] = 1 ; }
    for(int i = 0, x , y , w ; i < N-1 ; i++ )
    {
        if( scanf("%d %d %d", &x, &y, &w ) != 3 ) return 1 ;
        adj[x].pb(mk(y,i)) ;
        adj[y].pb(mk(x,i) );
        edges.pb(mkt(w, x, y,i)) ;
    }
    sort(all(edges)) ;
    // Kruskal order; rank (i+1)*2 gives every edge a distinct even label.
    for(int i = 0 , x , y , w ; i < N-1; i++ )
    {
        w = get<0>(edges[i]) ;
        x = get<1>(edges[i]) ;
        y = get<2>(edges[i]) ;
        join(x,y,w, (i+1)*2 ) ;
        weight[get<3>(edges[i])] = (i+1)*2 ;
    }
    decompose(1) ;
    printf("%lld\n" , ans*2LL ) ;
}

Compilation message (stderr)

Main.cpp: In function 'int main()':
Main.cpp:231:7: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  231 |  scanf("%d %d", &N, &K ) ;
      |  ~~~~~^~~~~~~~~~~~~~~~~~
Main.cpp:241:8: warning: ignoring return value of 'int scanf(const char*, ...)' declared with attribute 'warn_unused_result' [-Wunused-result]
  241 |   scanf("%d %d %d", &x, &y, &w ) ;
      |   ~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...