#include <bits/stdc++.h>
using namespace std;
// Shorthand macro; not used in the visible code (`long long` is spelled out).
#define ll long long
// Array capacity. Vertices index these arrays directly, so this assumes
// 1-based vertex labels with n < N (TODO confirm against the input spec).
const int N = 1e5 + 10;
// n: number of vertices (read in main). k: not used in the visible code.
int n, k;
// Adjacency lists; main inserts every edge in both directions.
vector<int> g[N];
// pa[u]: parent of u as set by dfs() (the start vertex is its own parent).
// Written but not read anywhere in the visible code.
int pa[N];
// d_pa[u]: length of the longest path from u whose first edge goes to u's
// parent, filled by dfs2(); stays 0 (zero-initialized global) for the root.
long long d_pa[N];
// md[u][0..2]: the three longest downward branches of u as
// (child, branch length) pairs in non-increasing order of length;
// slots not filled by a child hold {u, 0}.
pair<int, long long> md[N][3];
// Best value found by solve(); printed as the first output line.
long long ans;
// Vertex that achieved `ans` in solve(); predfs()/dfs3() are rooted here.
int root;
// Post-order DFS from u (p is the vertex we came from).
// Sets pa[u] = p and, after visiting all children, leaves in md[u][0..2]
// the three longest downward branches of u as (child, length) pairs sorted
// by non-increasing length. Slots no child fills stay {u, 0}. On equal
// lengths, the child visited earlier keeps the lower slot.
void dfs(int u, int p) {
    pa[u] = p;
    for (auto &slot : md[u]) slot = {u, 0};

    for (int v : g[u]) {
        if (v == p) continue;
        dfs(v, u);

        const long long len = md[v][0].second + 1;  // +1 for the edge u-v

        // First slot whose stored length is strictly shorter than len.
        int pos = 0;
        while (pos < 3 && md[u][pos].second >= len) ++pos;
        if (pos == 3) continue;  // not among the top three

        // Shift shorter entries down one slot; the last one falls off.
        for (int j = 2; j > pos; --j) md[u][j] = md[u][j - 1];
        md[u][pos] = {v, len};
    }
}
// Pre-order pass run after dfs() from the same start vertex.
// For each child v of u, d_pa[v] becomes the longest path that starts at v
// and leaves through u. The path either keeps going up (d_pa[u]) or turns
// into u's longest downward branch that is not v's own, plus one edge (v-u).
// d_pa of the start vertex is never assigned here, so it stays at its
// zero-initialized value.
void dfs2(int u, int p) {
    for (int v : g[u]) {
        if (v == p) continue;
        const long long other_branch =
            (v == md[u][0].first) ? md[u][1].second : md[u][0].second;
        d_pa[v] = max(d_pa[u], other_branch) + 1;
        dfs2(v, u);
    }
}
// Roots the tree at vertex 1 (dfs + dfs2), then evaluates every vertex i as a
// meeting point of three branches. The candidate branches are the three
// longest downward ones (a >= b >= c) and the upward one (up). Stores the
// maximum in `ans` and its vertex in `root`. Because the comparison is >=,
// the highest-numbered vertex wins on ties.
void solve() {
    dfs(1, 1);
    dfs2(1, 1);
    for (int i = 1; i <= n; ++i) {
        const long long a = md[i][0].second;
        const long long b = md[i][1].second;
        const long long c = md[i][2].second;
        const long long up = d_pa[i];

        // b == 0: at most one child. c == 0 && up == 0: the root with at most
        // two children. Either way i has fewer than three branches.
        const bool too_few_branches = (b == 0) || (c == 0 && up == 0);
        if (too_few_branches) continue;

        const long long best =
            max({0LL, a * (b + c), a * (b + up), up * (a + b)});
        if (best >= ans) {
            ans = best;
            root = i;
        }
    }
}
// sz[u]: subtree size of u, filled by predfs(); main calls it as predfs(root, root).
int sz[N];
// cnt[d]: number of tracked leaves (sz == 1) at depth d from `root`,
// maintained by dfs3() and update() with +1 / -1 adjustments.
long long cnt[N];
// Accumulated count produced by dfs3(); printed as the second output line.
long long total;
// Post-order DFS from u (arriving from p) that stores in sz[u] the number of
// vertices in u's subtree, u included.
void predfs(int u, int p) {
    int subtree = 1;
    for (int v : g[u]) {
        if (v == p) continue;
        predfs(v, u);
        subtree += sz[v];
    }
    sz[u] = subtree;
}
// Walks the subtree entered at u from p and adds val to cnt[d] for every
// leaf (a vertex with sz == 1), where d is the leaf's depth given that u
// sits at depth cur_d. dfs3() calls it with val = -1 / +1 to remove or
// re-add a light child's leaves.
void update(int u, int p, int cur_d, int val) {
    const bool is_leaf = (sz[u] == 1);
    if (is_leaf) {
        cnt[cur_d] += val;
    }
    const int child_depth = cur_d + 1;
    for (size_t i = 0; i < g[u].size(); ++i) {
        const int v = g[u][i];
        if (v == p) continue;
        update(v, u, child_depth, val);
    }
}
// Counting pass run from `root` after predfs(root, root).
// Small-to-large ("sack") scheme. It is called with cnt[] empty, and on return
// cnt[d] holds the number of leaves (sz == 1) of u's subtree at depth d from
// `root`; u itself is at depth cur_d. Light children are solved and removed
// again with update(-1). The child b with the largest sz is solved last and
// its counts are kept. The light children are then re-added with update(+1).
// Afterwards, leaf pairs read from cnt[cur_d + md[u][i].second] are added to
// `total`, depending on which formula reaches `ans`.
// NOTE(review): md[] and d_pa[] were computed by solve() with vertex 1 as the
// root, while this pass is rooted at `root`. For vertices on the path between
// vertex 1 and `root`, the "downward" branches differ between the two
// rootings, and x[i] + cur_d can then exceed the depths present here (and
// possibly N). Confirm this is intended.
void dfs3(int u, int p, int cur_d) {
// A leaf only records itself at its own depth.
if (sz[u] == 1) {
cnt[cur_d] ++;
return;
}
// b: child with the largest subtree (first one on ties).
int b = 0;
for (int v : g[u]) if (v != p) {
if (b == 0 || sz[v] > sz[b]) b = v;
}
// Solve each light child, then remove its leaves so cnt[] is empty again.
for (int v : g[u]) if (v != p && v != b) {
dfs3(v, u, cur_d + 1);
update(v, u, cur_d + 1, - 1);
}
// The heavy child's counts are kept ...
dfs3(b, u, cur_d + 1);
// ... and the light children's leaves are added back on top of them.
for (int v : g[u]) if (v != p && v != b) {
update(v, u, cur_d + 1, 1);
}
/*
cout << u << "\n";
for (int i = 0; i <= n; i ++) {
cout << i << " : " << cnt[i] << "\n";
}
cout << "\n";
*/
// x[0] >= x[1] >= x[2]: lengths of u's three longest downward branches.
// NOTE(review): solve() also considers x[0] * (x[1] + d_pa[u]); that case is
// not tested here. Also, cnt[] covers all of u's subtree, so a counted pair
// may come from the same child branch. Verify both points.
long long x[3];
for (int i = 0; i < 3; i ++) x[i] = md[u][i].second;
if (x[0] * (x[1] + x[2]) < ans && d_pa[u] * (x[0] + x[1]) < ans) return;
else if (d_pa[u] * (x[0] + x[1]) < ans) {
// Only x[0] * (x[1] + x[2]) reaches ans: pair leaves at depth x[1] with
// leaves at depth x[2] below u. The -1 and /2 terms adjust for equal lengths.
if (x[1] == x[0]) {
if (x[2] == x[1]) {
total += (cnt[x[1] + cur_d] - 1) * (cnt[x[1] + cur_d] - 2) / 2;
} else {
total += (cnt[x[1] + cur_d] - 1) * cnt[x[2] + cur_d];
}
} else {
if (x[2] == x[1]) {
total += (cnt[x[1] + cur_d] - 1) * cnt[x[1] + cur_d] / 2;
} else {
total += cnt[x[1] + cur_d] * cnt[x[2] + cur_d];
}
}
} else {
// d_pa[u] * (x[0] + x[1]) reaches ans (whether or not the other formula
// does): pair leaves at depths x[0] and x[1] below u (choose-2 if equal).
if (x[1] == x[0]) {
total += cnt[x[0] + cur_d] * (cnt[x[0] + cur_d] - 1) / 2;
} else {
total += cnt[x[0] + cur_d] * cnt[x[1] + cur_d];
}
}
}
// Reads n and the n - 1 undirected edges, then prints two lines.
// The first is the best value from solve(). The second is the count from
// dfs3() rooted at the chosen vertex, or 1 when the best value is 0.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    cin >> n;
    int a, b;  // declared outside the loop, as in the original read loop
    for (int edge = 1; edge < n; ++edge) {
        cin >> a >> b;
        g[a].push_back(b);
        g[b].push_back(a);
    }

    root = 1;
    solve();
    cout << ans << "\n";

    if (ans != 0) {
        predfs(root, root);
        dfs3(root, root, 0);
        cout << total << "\n";
    } else {
        cout << 1 << "\n";
    }
    return 0;
}
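/*
Grader status table pasted from the submission page. It is not part of the
program and is kept inside this comment so the file still compiles:
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/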