제출 #1281662

#제출 시각아이디문제언어결과실행 시간메모리
1281662mwenHard route (IZhO17_road)C++20
0 / 100
1 ms576 KiB
#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
#define sz(x) (int)(x).size()
#define all(x) begin(x), end(x)

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n;
    cin >> n;
    vector<vector<int>> con(n);
    for(int i = 0; i < n-1; i++) {
        int a, b;
        cin >> a >> b;
        a--; b--;
        con[a].push_back(b);
        con[b].push_back(a);
    }
    vector<ll> d(n);
    function<void(int, int)> dfs = [&](int curr, int prev) {
        for(int next : con[curr]) {
            if(next == prev)
                continue;
            dfs(next, curr);
            d[curr] = max(d[curr], d[next]+1);
        }
    };
    dfs(0, -1);
    
    ll best = 0, cnt = 1;
    auto updateBest = [&](ll v, ll ways) {
        if(v > best) {
            best = v;
            cnt = ways;
        }
        else if(v == best) {
            cnt += ways;
        }
    };
    
    function<void(int, int, ll)> reroot = [&](int curr, int prev, ll parDist) {
        map<ll, ll> dists;
        dists[parDist]++;
        for(int next : con[curr]) {
            if(next == prev)
                continue;
            dists[d[next]+1]++;
        }
        //4 cases
        //Let A>B>C be the furthest distances from this node
        //We look at the counts of each
        //A B C
        //3
        //2 x
        //1 x
        //1 1 x
        if(sz(con[curr]) >= 3) {
            auto it = dists.rbegin();
            auto [A, cnt1] = *it;
            ll v, ways;
            //3
            if(cnt1 >= 3) {
                v = (A+A)*A;
                ways = cnt1*(cnt1-1)/2;
            }
            else {
                it++;
                auto [B, cnt2] = *it;
                if(cnt1+cnt2 >= 3) {
                    //2 x
                    if(cnt1 == 2) {
                        v = (A+B)*A;
                        ways = cnt1*cnt2;
                    }
                    //1 x
                    else {
                        v = (B+B)*A;
                        ways = cnt2*(cnt2-1)/2;
                    }
                }
                else {
                    it++;
                    auto [C, cnt3] = *it;
                    assert(cnt1 == 1 && cnt2 == 1);
                    v = (B+C)*A;
                    ways = cnt3;
                }
            }
            updateBest(v, ways);
        }
        
        for(int next : con[curr]) {
            if(next == prev)
                continue;
            auto it = dists.rbegin();
            if(it->second == 1 && d[next]+1 == it->first) {
                it--;
                reroot(next, curr, it->first+1);
                it++;
            }
            else {
                reroot(next, curr, it->first+1);
            }
        }

    };
    reroot(0, -1, 0);
    cout << best << " " << cnt << "\n";
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...