#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// A pll is used as (distance in edges, number of leaves at that distance);
// for `ans` it is (best score found, number of leaf pairs achieving it).
using pll = pair<ll, ll>;
// Array capacity: vertices are numbered 1..n, so n must be at most N - 1.
// A typed constant instead of a macro so it obeys scope and type rules.
constexpr int N = 500001;
ll n;  // number of vertices, read from input
// dp[v]:  farthest leaf below v (tree rooted at vertex 1), with its count.
// dp2[v]: farthest leaf reached from v by first stepping to v's parent.
// ans:    best (score, count) merged over all candidate center vertices.
pll dp[N], dp2[N], ans;
vector<int> adj[N];  // adjacency lists of the tree
vector<pll> a;       // scratch buffer: branch summaries of one vertex (dfs3)
// `p + k`: the same (distance, count) summary seen k edges further away;
// only the distance changes.
inline pll operator+(const pll& l, const ll r) {
    return pll(l.first + r, l.second);
}
// `l += r`: keep the summary with the larger distance; when the distances
// are equal, pool the counts. Returns l to allow chaining.
inline pll& operator+=(pll& l, const pll& r) {
    if (r.first < l.first) return l;  // r is farther-from-best: ignore it
    if (r.first == l.first) {
        l.second += r.second;         // tie: both contribute
    } else {
        l = r;                        // r is strictly better
    }
    return l;
}
// Writes a pll as "first second": one space, no trailing newline.
ostream& operator<<(ostream& stream, const pll& p) {
    return stream << p.first << ' ' << p.second;
}
// Bottom-up pass over the subtree of u (p is u's parent, 0 for the root):
// dp[x] becomes the merge of dp[c] + 1 over x's children c, i.e. the
// farthest seeded leaf below x together with how many leaves sit at that
// distance. Leaf entries must already be seeded by the caller.
// Uses an explicit stack instead of recursion: a path-like tree with close
// to N vertices is that many levels deep and can exhaust the call stack.
void dfs(int u, int p) {
    // Pre-order list of (vertex, parent); walking it backwards visits every
    // vertex after all of its descendants.
    vector<pair<int, int>> order;
    vector<pair<int, int>> stk;
    stk.emplace_back(u, p);
    while (!stk.empty()) {
        auto [x, px] = stk.back();
        stk.pop_back();
        order.emplace_back(x, px);
        for (int y : adj[x]) if (y != px) stk.emplace_back(y, x);
    }
    // The merge is order-independent, so children may be folded in any order.
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        auto [x, px] = *it;
        if (x != u) dp[px] += (dp[x] + 1);  // u itself reports to nobody
    }
}
// Top-down pass over the subtree of u (p is u's parent, 0 for the root).
// For every child c of a processed vertex x it merges into dp2[c] the best of
//   * dp2[x] + 1                 (leave x through its own parent), and
//   * dp[s] + 2 for siblings s   (c -> x -> s, then down s's subtree),
// using prefix/suffix merges so each vertex costs O(degree).
// Requires dp[] from dfs() and dp2[u] already seeded.
// An explicit stack replaces recursion (depth can approach N); a parent is
// always finished before its children are popped, which is all that matters.
void dfs2(int u, int p) {
    vector<pair<int, int>> stk;
    stk.emplace_back(u, p);
    vector<int> kids;    // children of the current vertex
    vector<pll> suffix;  // suffix[i] = merge of dp[kids[j]] + 2 for j >= i
    while (!stk.empty()) {
        auto [x, px] = stk.back();
        stk.pop_back();
        kids.clear();
        for (int y : adj[x]) if (y != px) kids.push_back(y);
        // Signed count: avoids the unsigned wrap of `size() - 2` on 0/1 kids.
        const int k = static_cast<int>(kids.size());
        suffix.assign(k, pll{});
        for (int i = k - 1; i >= 0; i--) {
            suffix[i] = dp[kids[i]] + 2;
            if (i + 1 < k) suffix[i] += suffix[i + 1];
        }
        pll prefix = dp2[x] + 1;  // up-branch merged with kids[0..i-1]
        for (int i = 0; i < k; i++) {
            pll cur = prefix;
            if (i + 1 < k) cur += suffix[i + 1];
            dp2[kids[i]] += cur;
            prefix += (dp[kids[i]] + 2);
            stk.emplace_back(kids[i], x);
        }
    }
}
void dfs3(int u, int p) {
for (auto v: adj[u]) if (v != p) dfs3(v, u);
a.clear();
for (auto v: adj[u]) if (v != p) a.push_back(dp[v] + 1);
if (u != 1) a.push_back(dp2[u]);
if (a.size() < 3) return;
sort(a.begin(), a.end(), greater<>());
// cout << u << '\n'; for (auto e: a) cout << e << ' '; cout << '\n';
if (a[1].first != a[2].first) {
ll n1 = 0, n2 = 0;
for (auto [x, y]: a) {
if (x == a[1].first) n1 += a[1].second;
else if (x == a[2].first) n2 += a[2].second;
else if (x < a[2].first) break;
}
ans += {a[0].first*(a[1].first+a[2].first), n1*n2};
} else {
ll n1 = 0, n2 = 0;
for (auto [x, y]: a) {
if (x == a[1].first) n2 += n1*y, n1 += y;
else if (x < a[1].first) break;
}
ans += {a[0].first*a[1].first*2ll, n2};
}
}
// Reads a tree (n, then n - 1 edges over vertices 1..n) and prints
// "<best score> <count>". A tree that is a simple path prints "0 1".
int main() {
    ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    cin >> n;
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    // Path check: exactly two degree-1 vertices and every other vertex of
    // degree 2. Degrees are tallied by value; indexing a fixed 3-slot table
    // by degree would write out of bounds for any vertex of degree >= 3.
    ll deg1 = 0, deg2 = 0;
    for (int i = 1; i <= n; i++) {
        if (adj[i].size() == 1) deg1++;
        else if (adj[i].size() == 2) deg2++;
    }
    if (deg1 == 2 && deg2 == n - 2) {cout << 0 << ' ' << 1 << '\n'; return 0;}
    // Seed leaf counts. Vertex 1 is the root: when it is a leaf it can only
    // be reached through dp2; every other leaf starts as (0, 1) in dp.
    dp2[1].second = (adj[1].size() == 1);
    for (int i = 2; i <= n; i++) if (adj[i].size() == 1) dp[i].second = 1;
    dfs(1, 0);   // bottom-up: dp
    dfs2(1, 0);  // top-down: dp2
    dfs3(1, 0);  // score every vertex as a center into ans
    cout << ans << '\n';
}
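/*
 * Judge results table pasted from the submission page (the results were
 * still loading when it was copied). Kept inside a comment so that the
 * file remains valid C++.
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */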