#include <bits/stdc++.h>
using namespace std;

// Shorthand for the 64-bit type used by every distance and count below.
using ll = long long;

// Capacity of the per-vertex arrays (vertices are indexed from 1).
// Presumably sized for n <= 5e5 plus headroom; confirm against the input limits.
constexpr int MAXN = 500'010;

// Each entry is a summary {max_dist, cnt_dist}:
//   .first  = distance in edges from the vertex to the farthest leaves of a branch,
//   .second = how many leaves lie at exactly that distance.
// dp_1[v][0..2]: the three best branches hanging below v, by non-increasing .first.
pair<ll, ll> dp_1[MAXN][3];
// dp_2[v]: the same summary for everything outside v's subtree.
pair<ll, ll> dp_2[MAXN];

// Adjacency lists of the tree.
vector<int> adj[MAXN];
void dfs_1(int v, int p){
// dfs_1 calcula os tres maiores dentro da subarvore do v
// dp_1 -> dentro da subarvore
// dp_1[0] >= dp_1[1] >= dp_1[2]
dp_1[v][0] = dp_1[v][1] = dp_1[v][2] = {-1e9, 0};
if(adj[v].size() == 1) dp_1[v][0] = {0, 1};
for(auto u : adj[v]){
if(u != p){
dfs_1(u, v);
pair<ll, ll> cur = dp_1[u][0];
for(int i=1; i<3; i++){
if(dp_1[u][i].first == dp_1[u][0].first){
cur.second += dp_1[u][i].second;
}
}
cur.first ++;
if(cur.first >= dp_1[v][0].first){
// {cur, dp_1[v][0], dp_1[v][1]}
dp_1[v][2] = dp_1[v][1];
dp_1[v][1] = dp_1[v][0];
dp_1[v][0] = cur;
} else if(cur.first >= dp_1[v][1].first){
// {dp_1[v][0], cur, dp_1[v][1]}
dp_1[v][2] = dp_1[v][1];
dp_1[v][1] = cur;
} else dp_1[v][2] = max(dp_1[v][2], cur);
}
}
// cout << v << " -> \n";
// cout << "(" << dp_1[v][0].first << ", " << dp_1[v][0].second << ")\n";
// cout << "(" << dp_1[v][1].first << ", " << dp_1[v][1].second << ")\n";
// cout << "(" << dp_1[v][2].first << ", " << dp_1[v][2].second << ")\n";
}
void dfs_2(int v, int p){
// dfs_2 calcula o mais distante fora da subarvore do v
// cout << v << " -> (" << dp_2[v].first << ", " << dp_2[v].second << ")\n";
map<ll, ll> freq;
for(auto u : adj[v]){
if(u != p){
for(int i=0; i<3; i++){
if(dp_1[u][i].first == dp_1[u][0].first){
freq[dp_1[u][0].first] += dp_1[u][i].second;
}
}
}
}
for(auto u : adj[v]){
if(u != p){
for(int i=0; i<3; i++){
if(dp_1[u][i].first == dp_1[u][0].first){
freq[dp_1[u][0].first] -= dp_1[u][i].second;
}
}
if(freq[dp_1[u][0].first] == 0) freq.erase(dp_1[u][0].first);
dp_2[u] = {dp_2[v].first + 1, dp_2[v].second};
if(!freq.empty()){
auto [d, cnt] = *freq.rbegin();
if(d + 2 > dp_2[u].first){
dp_2[u] = {d + 2, cnt};
} else if(d + 2 == dp_2[u].first) dp_2[u].second += cnt;
}
for(int i=0; i<3; i++){
if(dp_1[u][i].first == dp_1[u][0].first){
freq[dp_1[u][0].first] += dp_1[u][i].second;
}
}
}
}
freq.clear();
for(auto u : adj[v]) if(u != p) dfs_2(u, v);
}
int32_t main(){
cin.tie(0)->sync_with_stdio(0);
int n; cin >> n;
for(int i=1; i<n; i++){
int a, b; cin >> a >> b;
adj[a].push_back(b);
adj[b].push_back(a);
}
// cout << "checking dfs_1\n";
dfs_1(1, 1); // ok!
dp_2[1] = {0, 1};
// cout << "checking dfs_2\n";
dfs_2(1, 1); // ok!
pair<ll, ll> ans = {0, 0};
for(int i=1; i<=n; i++){
vector<pair<ll, ll>> d = {dp_1[i][0], dp_1[i][1], dp_1[i][2], dp_2[i]};
sort(d.rbegin(), d.rend());
ll a1 = d[0].first, b1 = d[1].first, c1 = d[2].first;
ll a2 = d[0].second, b2 = d[1].second, c2 = d[2].second;
pair<ll, ll> cur = {0, 0};
cur.first = (b1 + c1) * a1;
if(a1 == b1 && b1 == c1){
// todo mundo igual
cur.second = 3 * a2 * b2 * c2;
} else if(a1 == b1){
// dois primeiros
cur.second = 2 * a2 * b2 * c2;
} else if(b1 == c1){
// dois ultimos
cur.second = a2 * b2 * c2;
} else{
// ninguem
cur.second = a2 * b2 * c2;
}
ans = max(ans, cur);
}
cout << ans.first << " " << ans.second << "\n";
return 0;
}
/*
 * Judge status table pasted along with the submission (not part of the program):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/