#include <bits/stdc++.h>
#define all(x) x.begin(),x.end()
#define sz(x) ((int)x.size())
#define PB push_back
#define MP make_pair
#define ft first
#define sd second
#define pll pair<ll, ll>
using namespace std;
using ll = long long;

// Compile-time limits. OO and M are not referenced by the code in this file.
constexpr ll OO = 1e18;
constexpr int N = 200100;   // capacity of every vertex-indexed array
constexpr int M = 1010;

// --- cycle search state (CYCLE / dfs) ---
vector<int> vc;             // current DFS path, oldest vertex first
vector<int> cyc;            // vertices of the found cycle, newest path vertex first
vector<int> g[N];           // adjacency lists, 0-based vertices
int mrk[N];                 // 0 = unvisited, 1 = on the DFS path, 2 = finished
bool in_cyc[N];             // true for every vertex stored in cyc

// --- per-tree depth data (calc_max) ---
int mx1[N], kl1[N];         // deepest depth below a vertex, and how many vertices reach it
int mx2[N], kl2[N];         // runner-up depth via another branch, and its multiplicity

// --- input size and running answer ---
int n;                      // number of vertices (main also reads n edges)
int ans = -1;               // longest path length recorded so far
ll kans = 0;                // number of paths recorded with length ans
/*
 * Records the cycle closed by a back edge to `lst`.
 * Walks the DFS path vc from its newest vertex back down to `lst`,
 * appending each vertex to cyc and setting its in_cyc flag.
 * cyc is emptied first; existing in_cyc flags are left untouched.
 * Precondition: lst is somewhere on vc.
 */
void CYCLE(int lst){
    cyc.clear();
    for (int pos = sz(vc) - 1; ; --pos){
        const int node = vc[pos];
        cyc.PB(node);
        in_cyc[node] = 1;
        if (node == lst)
            break;
    }
}
void dfs(int v, int p){
if (sz(cyc) > 0) return;
vc.PB(v);
mrk[v] = 1;
for (int u : g[v]){
if (u == p) continue;
if (mrk[u] == 0)
dfs(u, v); else
if (mrk[u] == 1)
CYCLE(u);
}
mrk[v] = 2;
vc.pop_back();
}
void calc_max(int v, int p){
mx1[v] = 0; mx2[v] = -1;
kl1[v] = 1; kl2[v] = 0;
int ko = 1;
ll res = 0, sm = 0;
for (int u : g[v]){
if (u == p || in_cyc[u]) continue;
calc_max(u, v);
if (mx1[v] == mx1[u] + 1){
kl1[v] += kl1[u];
ko++;
res += sm * ll(kl1[u]);
sm += kl1[u];
} else if (mx1[v] < mx1[u] + 1){
swap(kl1[v], kl2[v]);
swap(mx1[v], mx2[v]);
kl1[v] = kl1[u];
ko = 1;
res = kl1[u];
sm = kl1[u];
mx1[v] = mx1[u] + 1;
} else if (mx2[v] == mx1[u] + 1){
kl2[v] += kl1[u];
} else if (mx2[v] < mx1[u] + 1) {
kl2[v] = kl1[u];
mx2[v] = mx1[u] + 1;
}
}
if (mx1[v] < 1) return;
if (ko > 1){
if (mx1[v] + mx1[v] > ans){
ans = mx1[v] + mx1[v];
kans = 0;
}
if (mx1[v] + mx1[v] == ans)
kans += res;
} else {
if (mx1[v] + mx2[v] > ans){
ans = mx1[v] + mx2[v];
kans = 0;
}
if (mx1[v] + mx2[v] == ans)
kans += ll(kl1[v]) * ll(kl2[v]);
}
}
int main(){
ios_base::sync_with_stdio(0); cin.tie(0);
// freopen("penguins.in","r",stdin); freopen("penguins.out","w",stdout);
// freopen("in.txt","r",stdin);
cin >> n;
for (int i = 0; i < n; i++){
int x, y;
cin >> x >> y;
x--; y--;
g[x].PB(y);
g[y].PB(x);
}
dfs(0, -1);
for (int cr : cyc)
calc_max(cr, -1);
bool good = (sz(cyc) % 2 == 0);
for (int i1 = 0; i1 < sz(cyc); i1++)
for (int i2 = i1 + 1; i2 < sz(cyc); i2++){
int dst = i2 - i1, n1 = cyc[i1], n2 = cyc[i2];
dst = min(dst, sz(cyc) - dst);
dst += mx1[n1] + mx1[n2];
if (dst > ans){
ans = dst;
kans = 0;
}
if (dst == ans)
kans += ll(kl1[n1]) * ll(kl1[n2]);
if (good && (i2 - i1) == (sz(cyc) >> 1) && dst == ans)
kans += ll(kl1[n1]) * ll(kl1[n2]);
}
cout << kans << '\n';
return 0;
}
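/*
 * Grader results:
 * | # | Result (결과) | Execution time (실행 시간) | Memory (메모리) | Grader output        |
 * |---|---------------|----------------------------|-----------------|----------------------|
 * | 1 | Correct       | 5 ms                       | 4984 KB         | Output is correct    |
 * | 2 | Incorrect     | 6 ms                       | 4984 KB         | Output isn't correct |
 * | 3 | Halted        | 0 ms                       | 0 KB            | -                    |
 */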
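/*
 * Grader results:
 * | # | Result (결과) | Execution time (실행 시간) | Memory (메모리) | Grader output        |
 * |---|---------------|----------------------------|-----------------|----------------------|
 * | 1 | Correct       | 7 ms                       | 5112 KB         | Output is correct    |
 * | 2 | Incorrect     | 7 ms                       | 5112 KB         | Output isn't correct |
 * | 3 | Halted        | 0 ms                       | 0 KB            | -                    |
 */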
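/*
 * Grader results:
 * | # | Result (결과)         | Execution time (실행 시간) | Memory (메모리) | Grader output       |
 * |---|-----------------------|----------------------------|-----------------|---------------------|
 * | 1 | Correct               | 76 ms                      | 9984 KB         | Output is correct   |
 * | 2 | Correct               | 69 ms                      | 10360 KB        | Output is correct   |
 * | 3 | Execution timed out   | 1566 ms                    | 13172 KB        | Time limit exceeded |
 * | 4 | Halted                | 0 ms                       | 0 KB            | -                   |
 */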