이 제출은 이전 버전의 oj.uz에서 채점하였습니다. 현재는 제출 당시와는 다른 서버에서 채점을 하기 때문에, 다시 제출하면 결과가 달라질 수도 있습니다.
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Number of unordered pairs among x items, C(x, 2).
// A function rather than a macro so that (a) an argument such as `a + b` is taken as a
// whole, not spliced into `a + b*(a+b-1)`, and (b) an int argument is widened to
// long long before multiplying (x*(x-1) overflows int once x > 46341).
constexpr ll c2(ll x) { return x * (x - 1) / 2; }

// Summary of the branches around a vertex whose farthest leaves lie at one distance d.
struct dp{
    ll tot;      // number of leaves at distance d, summed over all such branches
    ll sep;      // number of separate branches that reach distance d
    ll badComb;  // sum over those branches of c2(branch leaf count): pairs inside one branch
    ll d;        // the distance itself
};

const int mx = 5e5+5;        // array capacity; presumably the maximum vertex count + slack
int n, root = -1;            // vertex count; root = first vertex with degree > 2 (set in main)
std::vector<int> adj[mx];    // adjacency lists, 0-based vertices
dp in[mx][3];                // in[v][0..2]: three largest distinct child-branch depths of v, descending
dp out[mx];                  // out[v]: farthest leaves reached from v by first stepping to its parent
std::pair<ll, ll> ans;       // (best value, accumulated count for that value), kept by updAns
void updIn(int c, int val, int inc){
if (val == in[c][0].d) in[c][0].tot += inc, in[c][0].sep++, in[c][0].badComb += c2(inc);
else if (val == in[c][1].d) in[c][1].tot += inc, in[c][1].sep++, in[c][1].badComb += c2(inc);
else if (val == in[c][2].d) in[c][2].tot += inc, in[c][2].sep++, in[c][2].badComb += c2(inc);
else if (val > in[c][0].d) in[c][2] = in[c][1], in[c][1] = in[c][0], in[c][0] = {inc, 1, c2(inc), val};
else if (val > in[c][1].d) in[c][2] = in[c][1], in[c][1] = {inc, 1, c2(inc), val};
else if (val > in[c][2].d) in[c][2] = {inc, 1, c2(inc), val};
}
void updOut(int node, int nxt){
dp cmpIn = in[node][0];
if (in[nxt][0].d+1 == in[node][0].d and in[node][0].sep == 1) cmpIn = in[node][1];
if (cmpIn.d > out[node].d) out[nxt] = {cmpIn.tot, 1, c2(cmpIn.tot), cmpIn.d+1};
else if (cmpIn.d < out[node].d) out[nxt] = {out[node].tot, 1, c2(out[node].tot), out[node].d+1};
else out[nxt] = {cmpIn.tot+out[node].tot, 1, c2(cmpIn.tot+out[node].tot), out[node].d+1};
}
// Offer a candidate (val, cnt) to the global record `ans`.
// A strictly larger val replaces the record, an equal val adds cnt to the running count,
// and a smaller val is ignored.
void updAns(ll val, ll cnt){
    if (val < ans.first) return;
    if (val > ans.first){
        ans.first = val;
        ans.second = 0;
    }
    ans.second += cnt;
}
void dfs1(int node, int p){
for (int nxt : adj[node])
if (nxt != p)
dfs1(nxt, node), updIn(node, in[nxt][0].d+1, in[nxt][0].tot);
if (adj[node].size() == 1) in[node][0].tot = 1;
}
// Fill out[] for every vertex below `node` (entered from parent p), top-down.
// out[child] is derived from out[v] and in[v], so each parent is handled before its
// children. An explicit stack replaces recursion to avoid O(n) call depth on long paths.
void dfs2(int node, int p){
    std::vector<std::pair<int, int>> stk;   // (vertex, parent)
    stk.push_back({node, p});
    while (!stk.empty()){
        int v = stk.back().first, par = stk.back().second;
        stk.pop_back();
        for (int nxt : adj[v])
            if (nxt != par){
                updOut(v, nxt);
                stk.push_back({nxt, v});
            }
    }
}
void solve(int node, int p){
dp cmp[3] = {in[node][0], in[node][1], in[node][2]};
for (int i = 0; i < 3; i++){
if (out[node].d == cmp[i].d and out[node].tot != 0){
cmp[i].tot += out[node].tot; cmp[i].sep += 1; cmp[i].badComb += out[node].badComb;
break;
}
if (out[node].d > cmp[i].d){
rotate(cmp+i, cmp+2, cmp+3); cmp[i] = out[node];
break;
}
}
if (cmp[0].sep > 2) updAns(cmp[0].d*(cmp[0].d*2), c2(cmp[0].tot)-cmp[0].badComb);
else if (cmp[0].sep == 2 and cmp[1].sep != 0) updAns(cmp[0].d*(cmp[0].d+cmp[1].d), cmp[0].tot*cmp[1].tot);
else if (cmp[0].sep == 1 and cmp[1].sep > 1) updAns(cmp[0].d*(cmp[1].d*2), c2(cmp[1].tot)-cmp[1].badComb);
else if (cmp[0].sep == 1 and cmp[1].sep == 1 and cmp[2].sep != 0) updAns(cmp[0].d*(cmp[1].d+cmp[2].d), cmp[1].tot*cmp[2].tot);
for (int nxt : adj[node])
if (nxt != p)
solve(nxt, node);
}
// Read a tree with n vertices (n-1 edges, 1-based in the input) and root it at a vertex of
// degree > 2. Run the bottom-up (dfs1), top-down (dfs2) and scoring (solve) passes, then
// print "<best value> <count>".
int main() {
    ios_base::sync_with_stdio(0); cin.tie(0);
    cin >> n;
    for (int i = 0; i < n-1; i++){
        int a, b; cin >> a >> b; a--; b--;
        adj[a].push_back(b); adj[b].push_back(a);
    }
    for (int i = 0; i < n; i++) if (adj[i].size() > 2) { root = i; break; }
    // No vertex has degree > 2 (for a tree: it is a path): the fixed answer "0 1" is printed.
    if (root == -1) { cout << 0 << " " << 1 << '\n'; return 0; }
    // dfs2 never writes out[root], so give every out[] an empty parent side: no leaves and a
    // depth below every real one. The original `dp({0, 0, -LLONG_MAX})` put the sentinel
    // into badComb (aggregate member order is tot, sep, badComb, d) and left d = 0.
    fill(out, out+mx, dp{0, 0, 0, -LLONG_MAX});
    dfs1(root, -1);
    dfs2(root, -1);
    solve(root, -1);
    cout << ans.first << " " << ans.second << '\n';
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |