Submission #1277284

#TimeUsernameProblemLanguageResultExecution timeMemory
1277284julia_08Hard route (IZhO17_road)C++20
0 / 100
3 ms572 KiB
#include <bits/stdc++.h>
using namespace std;

using ll = long long;

const int MAXN = 5e5 + 10;

int pai[MAXN];

pair<ll, ll> dp_1[MAXN][3], dp_2[MAXN];
// {max_dist, cnt_dist}

vector<int> adj[MAXN];

void dfs_1(int v, int p){

  // dfs_1 calcula os tres maiores dentro da subarvore do v

  // dp_1 -> dentro da subarvore
  // dp_1[0] >= dp_1[1] >= dp_1[2]

  dp_1[v][0] = dp_1[v][1] = dp_1[v][2] = {-1e9, 0};

  if(adj[v].size() == 1) dp_1[v][0] = {0, 1};

  for(auto u : adj[v]){
    if(u != p){
  
      pai[u] = v;
      dfs_1(u, v);

      pair<ll, ll> cur = dp_1[u][0];

      for(int i=1; i<3; i++){
        if(dp_1[u][i].first == dp_1[u][0].first){
          cur.second += dp_1[u][i].second;
        }
      }

      cur.first ++;

      if(cur.first >= dp_1[v][0].first){
        
        // {cur, dp_1[v][0], dp_1[v][1]}

        dp_1[v][2] = dp_1[v][1];
        dp_1[v][1] = dp_1[v][0];
        dp_1[v][0] = cur;

      } else if(cur.first >= dp_1[v][1].first){

        // {dp_1[v][0], cur, dp_1[v][1]}
        
        dp_1[v][2] = dp_1[v][1];
        dp_1[v][1] = cur;

      } else dp_1[v][2] = max(dp_1[v][2], cur);

    }
  }

  // cout << v << " -> \n";
  // cout << "(" << dp_1[v][0].first << ", " << dp_1[v][0].second << ")\n";
  // cout << "(" << dp_1[v][1].first << ", " << dp_1[v][1].second << ")\n";
  // cout << "(" << dp_1[v][2].first << ", " << dp_1[v][2].second << ")\n";

}

void dfs_2(int v, int p){

  // dfs_2 calcula o mais distante fora da subarvore do v

  // cout << v << " -> (" << dp_2[v].first << ", " << dp_2[v].second << ")\n";

  map<ll, ll> freq;

  for(auto u : adj[v]){
    if(u != p){

      for(int i=0; i<3; i++){
        if(dp_1[u][i].first == dp_1[u][0].first){
          freq[dp_1[u][0].first] += dp_1[u][i].second;
        }
      }

    }
  }

  for(auto u : adj[v]){
    if(u != p){

      for(int i=0; i<3; i++){
        if(dp_1[u][i].first == dp_1[u][0].first){
          freq[dp_1[u][0].first] -= dp_1[u][i].second;
        }
      }

      if(freq[dp_1[u][0].first] == 0) freq.erase(dp_1[u][0].first);

      dp_2[u] = {dp_2[v].first + 1, dp_2[v].second};

      if(!freq.empty()){

        auto [d, cnt] = *freq.rbegin();

        if(d + 2 > dp_2[u].first){  
          dp_2[u] = {d + 2, cnt};
        } else if(d + 2 == dp_2[u].first) dp_2[u].second += cnt;

      }

      for(int i=0; i<3; i++){
        if(dp_1[u][i].first == dp_1[u][0].first){
          freq[dp_1[u][0].first] += dp_1[u][i].second;
        }
      }

    }
  }

  freq.clear();

  for(auto u : adj[v]) if(u != p) dfs_2(u, v);

}

int32_t main(){
  cin.tie(0)->sync_with_stdio(0);

  int n; cin >> n;

  for(int i=1; i<n; i++){
    int a, b; cin >> a >> b;
    adj[a].push_back(b);
    adj[b].push_back(a);
  }

  // cout << "checking dfs_1\n";
  dfs_1(1, 1); // ok!

  dp_2[1] = {0, 1};

  // cout << "checking dfs_2\n";
  dfs_2(1, 1); // ok!

  pair<ll, ll> ans = {0, 0};

  for(int i=1; i<=n; i++){

    vector<pair<ll, ll>> d = {dp_1[i][0], dp_1[i][1], dp_1[i][2], dp_2[i]};

    sort(d.rbegin(), d.rend());

    ll a1 = d[0].first, b1 = d[1].first, c1 = d[2].first;
    ll a2 = d[0].second, b2 = d[1].second, c2 = 0;
    
    if(dp_2[i].first == c1) c2 += dp_2[i].second;
    
    for(auto j : adj[i]){
      if(dp_1[j][0].first + 1 == c1){
        if(j != pai[i]) c2 += dp_1[j][0].second;
      }
    }
    
    if(a1 == c1) c2 -= a2;
    if(b1 == c1) c2 -= b2;

    pair<ll, ll> cur = {0, 0};

    cur.first = (b1 + c1) * a1;
    
    if(b1 == c1){
      
      for(auto j : adj[i]){
        if(dp_1[j][0].first + 1 == c1){
          if(j != pai[i]){
            cur.second += dp_1[j][0].second * (c2 - dp_1[j][0].second);
          }
        }
      }
      
      if(dp_2[i].first + 1 == c1) cur.second += dp_2[i].second * (c2 - dp_2[i].second);
      
      if(a1 == c1) cur.second -= a2 * (c2 - a2);
      if(b1 == c1) cur.second -= b2 * (c2 - b2);
      
    }

    if(a1 == b1 && b1 == c1){

      // todo mundo igual 
      cur.second += a2 * b2 + b2 * c2 + c2 * a2;

    } else if(a1 == b1){

      // dois primeiros 
      cur.second += a2 * c2 + b2 * c2;

    } else if(b1 == c1){

      // dois ultimos
      cur.second += b2 * c2;
      
    } else{
      
      // ninguem
      cur.second += b2 * c2;
      
    }
  
    if(cur.first > ans.first){
      ans = cur;
    } else if(cur.first == ans.first) ans.second += cur.second;

  }

  cout << ans.first << " " << ans.second << "\n";

  return 0;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...