제출 #1281670

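# | 제출 시각 (submission time) | 아이디 (user) | 문제 (problem) | 언어 (language) | 결과 (result) | 실행 시간 (execution time) | 메모리 (memory)
1281670 |  | mwen | Hard route (IZhO17_road) | C++20 | 52 / 100 | 504 ms | 271900 KiB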
// Solution for "Hard route" (IZhO17_road).
//
// For every vertex v of degree >= 3, the branches hanging off v are ranked by
// the distance A >= B >= C >= ... from v to their farthest leaf, and the
// candidate value is A * (B + C): a leaf-to-leaf path through the B- and
// C-branches whose farthest off-path vertex lies in the A-branch. The program
// prints the largest candidate over all vertices together with the total
// number of leaf pairs realising it, or "0 1" when no vertex has degree >= 3.
// NOTE(review): the problem definition (hardness = path length * largest
// distance of a vertex from the path) is inferred from these formulas --
// confirm against the official statement.
//
// Both passes are iterative (explicit stack + preorder array) so deep,
// path-shaped trees do not overflow the call stack, and the per-vertex scratch
// buffers are reused instead of being kept alive on every recursion level.
#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>

namespace {

using ll = long long;

// (distance to the farthest leaf of a branch, number of leaves at that distance).
using Reach = std::pair<ll, ll>;

// All branches of one vertex whose farthest leaf is at the same distance.
struct Group {
    ll dist = 0;       // shared farthest-leaf distance
    ll leaves = 0;     // leaves at `dist`, summed over the group's branches
    ll branches = 0;   // number of branches in the group
    ll samePairs = 0;  // pairs of such leaves that lie in the same branch
};

// Number of unordered pairs that can be formed from x items.
ll choose2(ll x) { return x * (x - 1) / 2; }

// Sorts `branches` by descending distance and collapses equal distances into
// `groups` (also in descending order). Both vectors are caller-owned scratch
// buffers; `groups` is overwritten.
void groupBranches(std::vector<Reach>& branches, std::vector<Group>& groups) {
    std::sort(branches.begin(), branches.end(),
              [](const Reach& x, const Reach& y) { return x.first > y.first; });
    groups.clear();
    for (const auto& [dist, leaves] : branches) {
        if (groups.empty() || groups.back().dist != dist) {
            groups.push_back(Group{dist, 0, 0, 0});
        }
        Group& g = groups.back();
        g.leaves += leaves;
        ++g.branches;
        g.samePairs += choose2(leaves);
    }
}

// Best candidate A * (B + C) at a vertex with at least three real branches,
// paired with the number of leaf pairs realising it. `groups` must be the
// output of groupBranches for that vertex.
Reach bestAtVertex(const std::vector<Group>& groups) {
    const Group& a = groups[0];
    if (a.branches >= 3) {
        // A == B == C: both endpoints at distance A, in two different
        // A-branches (a third A-branch then holds the farthest vertex).
        return {2 * a.dist * a.dist, choose2(a.leaves) - a.samePairs};
    }
    const Group& b = groups[1];
    if (a.branches == 2) {
        // Two branches tie at A: one endpoint in either of them, the other
        // endpoint at the next smaller distance.
        return {(a.dist + b.dist) * a.dist, a.leaves * b.leaves};
    }
    if (b.branches >= 2) {
        // Unique farthest branch; both endpoints at the second distance, in
        // two different branches.
        return {2 * b.dist * a.dist, choose2(b.leaves) - b.samePairs};
    }
    // Unique first and second branches; the other endpoint is at the third
    // distance (in any branch of that group).
    const Group& c = groups[2];
    return {(b.dist + c.dist) * a.dist, b.leaves * c.leaves};
}

// The reach a child sees through its parent: the farthest leaf reachable from
// the parent without entering the child's own branch `own` (as seen from the
// parent), shifted one edge further away.
Reach reachExcluding(const std::vector<Group>& groups, const Reach& own) {
    const Group& top = groups[0];
    if (own.first != top.dist) {
        // The child's branch is not among the farthest ones: nothing to remove.
        return {top.dist + 1, top.leaves};
    }
    if (top.branches > 1) {
        // Other branches share the farthest distance: remove only our leaves.
        return {top.dist + 1, top.leaves - own.second};
    }
    // The child's branch was the unique farthest one. A second group exists
    // because every vertex also has a parent (or, at the root, phantom) branch.
    return {groups[1].dist + 1, groups[1].leaves};
}

}  // namespace

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    if (!(std::cin >> n) || n <= 0) {
        // Empty tree: nothing can have positive hardness.
        std::cout << 0 << ' ' << 1 << '\n';
        return 0;
    }

    std::vector<std::vector<int>> adj(n);
    for (int i = 0; i + 1 < n; ++i) {
        int a = 0;
        int b = 0;
        std::cin >> a >> b;
        --a;
        --b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    // Preorder from vertex 0: every vertex appears after its parent.
    std::vector<int> parent(n, -1);
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> pending{0};
    while (!pending.empty()) {
        const int v = pending.back();
        pending.pop_back();
        order.push_back(v);
        for (int w : adj[v]) {
            if (w != parent[v]) {
                parent[w] = v;
                pending.push_back(w);
            }
        }
    }

    // down[v]: farthest leaf inside v's subtree and how many leaves are that
    // far. A leaf is its own farthest leaf: {0, 1}.
    std::vector<Reach> down(n, Reach{0, 1});
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int v = *it;
        for (int c : adj[v]) {
            if (c == parent[v]) continue;
            const ll dist = down[c].first + 1;
            if (dist > down[v].first) {
                down[v] = {dist, down[c].second};
            } else if (dist == down[v].first) {
                down[v].second += down[c].second;
            }
        }
    }

    // up[v]: farthest leaf reachable from v through its parent. The root gets
    // a phantom branch {0, 1} (the root itself as a leaf). It is shorter than
    // every real branch, so it only matters when the root is a leaf.
    std::vector<Reach> up(n);
    up[0] = {0, 1};

    ll best = 0;
    ll ways = 1;
    std::vector<Reach> branches;
    std::vector<Group> groups;
    for (int v : order) {
        branches.clear();
        branches.push_back(up[v]);
        for (int c : adj[v]) {
            if (c != parent[v]) branches.push_back({down[c].first + 1, down[c].second});
        }
        groupBranches(branches, groups);

        if (adj[v].size() >= 3) {
            const auto [value, count] = bestAtVertex(groups);
            if (value > best) {
                best = value;
                ways = count;
            } else if (value == best) {
                ways += count;
            }
        }

        for (int c : adj[v]) {
            if (c != parent[v]) {
                up[c] = reachExcluding(groups, {down[c].first + 1, down[c].second});
            }
        }
    }

    std::cout << best << ' ' << ways << '\n';
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...