#include <bits/stdc++.h>
using namespace std;
// Shorthand used throughout the file. Note that `fi`, `sec`, `ll`, `pii` and
// `ld` are macros, so none of them may be reused as an identifier.
#define ll long long
#define pii pair<ll, ll>
#define fi first
#define sec second
// `ld`, MOD and eps are not referenced anywhere in the visible code.
#define ld long double
// Upper bound on the vertex count; every per-vertex array has MAXN + 5 slots.
const int MAXN = 2e5;
// "Minus infinity" magnitude used as the neutral element for maxima.
const ll INF = 1e18;
const int MOD = 998244353;
const ld eps = 1e-6;
// vis[v]   : set to 1 when dfs / dfs2 enters v (main clears it in between).
// cycle[v] : set to 1 by main for every vertex on the path st -> en along p[].
ll vis[MAXN + 5], cycle[MAXN + 5];
// Undirected adjacency lists, filled by main from the input edges.
vector<ll> adj[MAXN + 5];
// p[v] is the vertex from which dfs first reached v; (st, en) are the two
// endpoints of the edge on which dfs met an already-visited vertex.
ll p[MAXN + 5], st, en;
// d[v]: for cycle vertices, the position along the cycle assigned by main
// (0 .. sz-1); for every other vertex, d[parent] + 1 as assigned by dfs2.
ll d[MAXN + 5];
// dp[v] = (largest d value inside the subtree dfs2 explored below v,
//          number of vertices attaining that value).
pii dp[MAXN + 5];
// cnt / cntt : scratch counters overwritten by every dfs2 call; main reads
//              cnt right after each top-level dfs2 call.
// cur        : best path length found so far.
// ans        : number of ways achieving `cur`; main prints ans / 2.
ll cnt, cntt, ans, cur;
// Explores the tree hanging below `idx`, never stepping onto a cycle vertex
// or back to `par`.
//
// Effects on file-level state:
//   * d[child] is set to d[idx] + 1 for every explored child,
//   * dp[idx] becomes (deepest d value in the subtree, how many vertices
//     attain it),
//   * cnt ends up as the number of deepest vertices as seen from `idx`
//     (main reads it after a top-level call), cntt as the number of vertices
//     at the runner-up depth,
//   * cur / ans are updated with the best path found that has `idx` as its
//     topmost vertex. Every unordered pair contributes 2 to ans here, which
//     matches the final `ans / 2` in main.
void dfs2(ll idx, ll par){
    vis[idx] = 1;
    const ll here = d[idx];
    dp[idx] = {here, 1};

    // Largest and second-largest depth reachable from idx; idx itself
    // provides the initial candidate.
    ll best = here, runnerUp = -INF;
    for(const ll ch : adj[idx]){
        if(cycle[ch] || ch == par) continue;
        d[ch] = here + 1;
        dfs2(ch, idx);

        const pii &sub = dp[ch];
        if(sub.first > dp[idx].first) dp[idx] = sub;
        else if(sub.first == dp[idx].first) dp[idx].second += sub.second;

        if(sub.first > best){
            runnerUp = best;
            best = sub.first;
        }
        else if(sub.first > runnerUp){
            runnerUp = sub.first;
        }
    }

    // Second pass: the globals are filled only now, because the recursive
    // calls above overwrite them.
    vector<ll> tied;                 // per-child counts of deepest vertices
    cnt = (best == here) ? 1 : 0;
    cntt = 0;
    for(const ll ch : adj[idx]){
        if(cycle[ch] || ch == par) continue;
        const ll depth = dp[ch].first;
        const ll ways = dp[ch].second;
        if(depth == best){
            tied.push_back(ways);
            cnt += ways;
        }
        if(depth == runnerUp) cntt += ways;
    }

    // Offers a candidate (length, number of ways) to the global optimum.
    auto offer = [](ll len, ll ways){
        if(cur > len) return;
        if(cur < len){
            cur = len;
            ans = 0;
        }
        ans += ways;
    };

    if(tied.size() > 1){
        // Two or more children reach the maximum depth: pair up deepest
        // vertices that come from different children (ordered pairs).
        ll pairs = 0;
        for(const ll x : tied) pairs += (cnt - x) * x;
        offer((best - here) * 2, pairs);
    }
    else if(cntt == 0){
        // No child sits at the runner-up depth: the path runs from idx
        // straight down to a deepest vertex.
        offer(best - here, cnt * 2);
    }
    else{
        // One deepest branch combined with the runner-up branches.
        offer(best + runnerUp - 2 * here, cnt * cntt * 2);
    }
}
// Depth-first walk from `idx` (reached from `par`) that records p[] for every
// newly visited vertex. Whenever a neighbour other than `par` is already
// visited, the pair (st, en) is overwritten with (that neighbour, idx); main
// relies on the values left by the last such overwrite.
void dfs(ll idx, ll par){
    vis[idx] = 1;
    for(size_t k = 0; k < adj[idx].size(); ++k){
        const ll nxt = adj[idx][k];
        if(nxt == par) continue;
        if(!vis[nxt]){
            p[nxt] = idx;
            dfs(nxt, idx);
            continue;
        }
        st = nxt;
        en = idx;
    }
}
// Value kept in every segment-tree slot: a maximum and the number of
// positions that attain it.
struct node{
    long long MX, cnt;
};

// Point-update / range-query segment tree over "maximum with multiplicity".
//
// The tree does not remember its index range: every call passes the root
// range [l, r] together with the root slot `cur` (1 at the top level), and
// the same root range must be used for all calls on one tree.
struct ST{
    // Neutral maximum; numerically equal to the file-level -INF so callers
    // observe the same "empty range" value as before.
    static constexpr long long NEG_INF = -1000000000000000000LL;

    long long N;
    std::vector<node> sg;

    // Builds a tree for `_n` positions. Every slot starts as the identity
    // {NEG_INF, 0}; the previous zero-initialised slots made an unwritten
    // position look like "maximum 0, count 0", which silently masked real
    // negative maxima stored elsewhere in the queried range.
    explicit ST(long long _n){
        N = _n;
        sg.assign(4 * N + 5, node{NEG_INF, 0});
    }

    // Merges two aggregates: the larger maximum wins, equal maxima add their
    // counts.
    static node comb(node a, node b){
        if(a.MX == b.MX) return {a.MX, a.cnt + b.cnt};
        return (a.MX > b.MX) ? a : b;
    }

    // Overwrites position `idx` with `val`; (l, r, cur) describe the current
    // subtree and the caller starts with the root range and cur = 1.
    void upd(long long l, long long r, long long cur, long long idx, node val){
        if(l == r){
            sg[cur] = val;
            return;
        }
        const long long mid = (l + r) / 2;
        if(idx <= mid) upd(l, mid, cur * 2, idx, val);
        else upd(mid + 1, r, cur * 2 + 1, idx, val);
        sg[cur] = comb(sg[cur * 2], sg[cur * 2 + 1]);
    }

    // Aggregate over positions in [x, y]. Ranges that are empty or lie outside
    // the tree yield the identity {NEG_INF, 0}.
    node query(long long l, long long r, long long cur, long long x, long long y) const {
        if(l > y || r < x) return {NEG_INF, 0};
        if(l >= x && r <= y) return sg[cur];
        const long long mid = (l + r) / 2;
        return comb(query(l, mid, cur * 2, x, y),
                    query(mid + 1, r, cur * 2 + 1, x, y));
    }
};
int main(){
ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
int tc = 1;
// cin >> tc;
for(;tc--;){
ll N; cin >> N;
for(int i = 1; i <= N; i++){
ll u, v; cin >> u >> v;
adj[u].push_back(v), adj[v].push_back(u);
}
dfs(1, -1);
ll sz = 1;
d[st] = 0;
while(st != en){
cycle[st] = 1;
d[p[st]] = d[st] + 1;
sz++;
st = p[st];
}
cycle[en] = 1;
ST pos(sz), neg(sz);
for(int i = 1; i <= N; i++) vis[i] = 0;
vector<pair<ll, pii>> isi;
for(int tt = 1; tt <= N; tt++){
if(cycle[tt]){
dfs2(tt, -1);
ll i = d[tt], j = dp[tt].fi - i;
isi.push_back({i, {j, cnt}});
pos.upd(0, sz - 1, 1, i, {dp[tt].fi, dp[tt].sec});
neg.upd(0, sz - 1, 1, i, {dp[tt].fi - 2 * i, dp[tt].sec});
}
}
for(auto [i, val] : isi){
ll j = val.fi, cur_cnt = val.sec;
ll batas = i - sz / 2;
if(batas < 0){
node now = neg.query(0, sz - 1, 1, 0, i - 1);
if(cur < now.MX + i + j){
ans = now.cnt * cur_cnt;
cur = now.MX + i + j;
}
else if(cur == now.MX + i + j){
ans += now.cnt * cur_cnt;
}
now = neg.query(0, sz - 1, 1, sz + batas, sz - 1);
if(cur < now.MX + sz + i + j){
ans = now.cnt * cur_cnt;
cur = now.MX + sz + i + j;
}
else if(cur == now.MX + sz + i + j){
ans += now.cnt * cur_cnt;
}
now = pos.query(0, sz - 1, 1, i + 1, sz + batas - 1);
if(cur < now.MX - i + j){
ans = now.cnt * cur_cnt;
cur = now.MX - i + j;
}
else if(cur == now.MX - i + j){
ans += now.cnt * cur_cnt;
}
if(sz % 2 == 0){
now = pos.query(0, sz - 1, 1, i - sz / 2 + sz, i - sz / 2 + sz);
if(cur < now.MX - i + j){
ans = now.cnt * cur_cnt;
cur = now.MX - i + j;
}
else if(cur == now.MX - i + j){
ans += now.cnt * cur_cnt;
}
}
}
else{
node now = neg.query(0, sz - 1, 1, batas, i - 1);
if(cur < now.MX + i + j){
ans = now.cnt * cur_cnt;
cur = now.MX + i + j;
}
else if(cur == now.MX + i + j){
ans += now.cnt * cur_cnt;
}
now = pos.query(0, sz - 1, 1, 0, batas - 1);
if(cur < now.MX + sz - i + j){
ans = now.cnt * cur_cnt;
cur = now.MX + sz - i + j;
}
else if(cur == now.MX + sz - i + j){
ans += now.cnt * cur_cnt;
}
now = pos.query(0, sz - 1, 1, i + 1, sz);
if(cur < now.MX - i + j){
ans = now.cnt * cur_cnt;
cur = now.MX - i + j;
}
else if(cur == now.MX - i + j){
ans += now.cnt * cur_cnt;
}
if(sz % 2 == 0){
now = neg.query(0, sz - 1, 1, i - sz / 2, i - sz / 2);
if(cur < now.MX + i + j){
ans = now.cnt * cur_cnt;
cur = now.MX + i + j;
}
else if(cur == now.MX + i + j){
ans += now.cnt * cur_cnt;
}
}
}
}
cout << ans / 2 << "\n";
}
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |