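// 이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
// (oj.uz notice: this submission was judged on an earlier version of oj.uz. Grading now runs on a different server than at submission time, so resubmitting may give a different result.)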
#include <iostream>
#include <vector>
#include <algorithm>
#include <stack>
#include <queue>
#include <cmath>
#include <iomanip>
#include <unordered_map>
#include <utility>
#include <assert.h>
using namespace std;
// Debug helper (only referenced from commented-out lines): prints every argument
// followed by a single space, then terminates the line with endl.
void debug() {
    cout << endl;
}
template <typename Head, typename... Tail> void debug(Head head, Tail... tail) {
    cout << head << " ";
    debug(tail...);
}
// Debug helper: prints each element of the half-open range [l, r) followed by a
// space, then ends the line with endl.
template <typename Iter> void pary(Iter l, Iter r) {
    for (; l != r; ++l) cout << *l << " ";
    cout << endl;
}
// Shorthand macros. NOTE(review): ld, mod, pii, ff and ss are not referenced anywhere in this file.
#define ll long long
#define ld long double
// Capacity of every per-vertex array below; main() uses vertex ids (1..n) directly as indices.
#define maxn 500005
#define mod 1000000007
#define pii pair<int, int>
#define ff first
#define ss second
// Fast stream I/O: unsyncs iostreams from C stdio and unties cin from cout.
#define io ios_base::sync_with_stdio(0);cin.tie(0);
// Adjacency lists of the undirected tree read in main() (both directions are stored).
vector<int> adj[maxn];
// f[v][0..2]: the three largest branch lengths seen at v, kept in descending order by upd().
// g[v][k]:    the leaf count passed to upd() together with f[v][k].
// cnt[v][0]:  number of leaves at maximum depth in v's rooted subtree (set by dfs()).
// cnt[v][1]:  number of leaves at distance up[v] reached through v's parent (set by dfs2()).
// deg[v]:     degree of v; degree-1 vertices are treated as leaves.
// down[v]:    longest downward branch of v (f[v][0] right after dfs()).
// up[v]:      longest branch of v that goes through its parent (set by dfs2()).
int f[maxn][3], cnt[maxn][2], deg[maxn], up[maxn], down[maxn], g[maxn][3];
// Offers a branch of length x (carrying leaf count y) to vertex n's top-3 list.
// f[n][0..2] stays in descending order and g[n][k] moves together with f[n][k].
// Insertion is strict (x must exceed a slot's value), so among equal lengths the
// earlier-inserted branch keeps the higher slot; values that beat no slot are dropped.
void upd(int x, int n, int y) {
    int slot = 0;
    while (slot < 3 && x <= f[n][slot]) ++slot;
    if (slot == 3) return;
    // Shift the lower entries down by one to make room at `slot`.
    for (int k = 2; k > slot; --k) {
        f[n][k] = f[n][k - 1];
        g[n][k] = g[n][k - 1];
    }
    f[n][slot] = x;
    g[n][slot] = y;
}
// Post-order pass over the tree rooted at n (par is n's parent, 0 for the root).
// For every vertex u in the subtree it:
//   - offers each child branch (length f[child][0] + 1, count cnt[child][0]) to upd(),
//   - sets cnt[u][0] to the number of leaves at u's maximum depth
//     (a degree-1 vertex counts itself),
//   - records down[u] = f[u][0] before dfs2() adds the upward branch.
// The traversal uses an explicit stack instead of recursion. A recursive version
// recurses once per tree level, and a path-shaped tree with ~5e5 vertices can
// exhaust the call stack. Children are still visited in adjacency order, and each
// child is folded into its parent right after the child finishes, exactly as the
// recursive version did.
void dfs(int n, int par) {
    struct Frame {
        int v;
        int parent;
        std::size_t next;  // next index into adj[v] to examine
    };
    std::vector<Frame> st;
    st.push_back({n, par, 0});
    while (!st.empty()) {
        Frame &top = st.back();
        const int u = top.v;
        if (top.next < adj[u].size()) {
            const int v = adj[u][top.next++];
            // `top` may dangle after push_back; it is not used again this iteration.
            if (v != top.parent) st.push_back({v, u, 0});
            continue;
        }
        // All children of u are done: finalize u.
        const int p = top.parent;
        down[u] = f[u][0];
        if (deg[u] == 1) cnt[u][0]++;
        st.pop_back();
        if (st.empty()) break;  // u was the starting vertex; it has no parent to update
        // Fold u's deepest branch into its parent p.
        const int x = f[u][0] + 1;
        if (x > f[p][0]) cnt[p][0] = cnt[u][0];
        else if (x == f[p][0]) cnt[p][0] += cnt[u][0];
        upd(x, p, cnt[u][0]);
    }
}
void dfs2(int n, int par, int d) {
upd(d, n, cnt[n][1]);
up[n] = d;
int best = d, bcnt = cnt[n][1], sec = 0, scnt = 0;
for (int v:adj[n]) {
if (v != par) {
int x = f[v][0] + 1;
if (x > best) sec = best, scnt = bcnt, best = x, bcnt = cnt[v][0];
else if (x == best) bcnt += cnt[v][0];
else if (x > sec) sec = x, scnt = cnt[v][0];
else if (x == sec) scnt += cnt[v][0];
}
}
for (int v:adj[n]) {
if (v != par) {
if (f[v][0] + 1 == best) {
if (bcnt > cnt[v][0]) cnt[v][1] = bcnt - cnt[v][0], dfs2(v, n, best + 1);
else cnt[v][1] = scnt, dfs2(v, n, sec + 1);
} else {
cnt[v][1] = bcnt;
dfs2(v, n, best + 1);
}
}
}
}
// Largest value f[v][0] * (f[v][1] + f[v][2]) found by solve() (ans) and the summed
// count for every vertex attaining it (num); main() prints both.
ll ans = 0, num = 0;
void solve(int n, int par) {
vector<int> arr;
for (int v:adj[n]) {
if (v != par) {
if (down[v] + 1 == f[n][2]) arr.push_back(cnt[v][0]);
}
}
if (par && up[n] == f[n][2]) arr.push_back(cnt[n][1]);
ll tot = 0, nums = 0;
for (int i:arr) tot += i;
//debug(n);
//pary(arr.begin(), arr.end());
if (f[n][0] && f[n][1] && f[n][2]) {
if (f[n][0] == f[n][1] && f[n][1] == f[n][2]) {
ll p1 = 0;
for (int i:arr) {
nums += p1 * i;
p1 += i;
}
} else if (f[n][0] == f[n][1]) {
nums += tot * (g[n][1] + g[n][0]);
} else if (f[n][1] == f[n][2]) {
ll p1 = 0;
for (int i:arr) nums += p1 * i, p1 += i;
} else {
nums += tot * g[n][1];
}
tot = (ll)f[n][0] * (f[n][1] + f[n][2]);
//debug(n, tot, nums);
if (tot > ans) ans = tot, num = nums;
else if (ans == tot) num += nums;
}
for (int v:adj[n]) {
if (v != par) solve(v, n);
}
}
// Entry point: reads n and n - 1 edges, then runs dfs(), dfs2() and solve() from the
// first vertex of degree > 2 and prints "ans num".
// If no vertex has degree > 2 (every degree is <= 2), it prints "0 1" directly.
int main() {
    io
    int n;
    cin >> n;
    for (int e = 1; e < n; e++) {
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        deg[a]++;
        adj[b].push_back(a);
        deg[b]++;
    }

    int root = 0;
    for (int v = 1; v <= n && root == 0; v++) {
        if (deg[v] > 2) root = v;
    }
    if (root == 0) {
        cout << 0 << " " << 1 << endl;
        return 0;
    }

    dfs(root, 0);
    dfs2(root, 0, 0);
    solve(root, 0);
    cout << ans << " " << num << endl;
}
/* oj.uz results table copied together with the source (the verdicts had not loaded yet):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/