제출 #1281682

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1281682 | | mwen | Hard route (IZhO17_road) | C++20 | 52 / 100 | 555 ms | 304508 KiB
#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
#define sz(x) (int)(x).size()
#define all(x) begin(x), end(x)

// Solver for "Hard route" (IZhO17_road).
//
// Every vertex with at least three neighbours is treated as a junction. Each
// neighbour direction is a "branch" described by the distance to the farthest
// leaf reachable that way and the number of leaves at that distance. With
// a >= b >= c the three largest branch distances, the junction scores
// a * (b + c); the program prints the best score over all junctions and the
// number of leaf pairs that reach it ("0 1" when no junction exists).
//
// Both tree passes are iterative over a BFS order, so memory stays O(n) and
// tall (path-like) trees cannot exhaust the call stack.
namespace {

// One branch, or a pool of equally long branches, seen from a vertex.
struct Reach {
    int dist;          // edges to the farthest leaf in this direction
    long long leaves;  // number of leaves at exactly that distance
};

// Computes {best score, number of leaf pairs achieving it} for the tree given
// as 0-based adjacency lists. The input is assumed to be a tree (connected,
// no repeated edges); vertex 0 is used as the root.
std::pair<long long, long long> hardestRoutes(const std::vector<std::vector<int>>& adj) {
    const int n = static_cast<int>(adj.size());
    long long best = 0;
    long long ways = 1;
    if (n == 0) {
        return {best, ways};
    }

    // BFS order from vertex 0: every parent precedes its children.
    std::vector<int> parent(n, -1);
    std::vector<int> order;
    order.reserve(n);
    order.push_back(0);
    for (std::size_t head = 0; head < order.size(); ++head) {
        const int v = order[head];
        for (const int u : adj[v]) {
            if (u != parent[v]) {
                parent[u] = v;
                order.push_back(u);
            }
        }
    }

    // down[v]: deepest leaf distance inside v's subtree and how many leaves
    // sit at that depth. Walking the BFS order backwards finalises every
    // child before it is merged into its parent.
    std::vector<Reach> down(n, Reach{0, 1});
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int v = *it;
        const int p = parent[v];
        if (p < 0) {
            continue;
        }
        const int viaV = down[v].dist + 1;
        if (viaV > down[p].dist) {
            down[p] = Reach{viaV, down[v].leaves};
        } else if (viaV == down[p].dist) {
            down[p].leaves += down[v].leaves;
        }
    }

    // up[v]: the branch of v that leaves through its parent. The root has no
    // parent; its placeholder {0, 1} stands for "the root itself is an
    // endpoint", which is what a child needs when the root has one neighbour.
    // The placeholder can never be among the three longest branches of a root
    // with >= 3 neighbours because every real branch has distance >= 1.
    std::vector<Reach> up(n, Reach{0, 1});
    std::vector<Reach> branches;  // reused scratch buffer
    for (const int v : order) {
        branches.clear();
        branches.push_back(up[v]);
        for (const int u : adj[v]) {
            if (u != parent[v]) {
                branches.push_back(Reach{down[u].dist + 1, down[u].leaves});
            }
        }

        if (adj[v].size() >= 3) {
            // Three largest branch distances, counted with multiplicity.
            int a = -1;
            int b = -1;
            int c = -1;
            for (const Reach& r : branches) {
                if (r.dist > a) {
                    c = b;
                    b = a;
                    a = r.dist;
                } else if (r.dist > b) {
                    c = b;
                    b = r.dist;
                } else if (r.dist > c) {
                    c = r.dist;
                }
            }

            long long sumB = 0;     // leaves pooled over branches of distance b
            long long squareB = 0;  // sum of squared per-branch counts at b
            long long sumC = 0;     // leaves pooled over branches of distance c
            for (const Reach& r : branches) {
                if (r.dist == b) {
                    sumB += r.leaves;
                    squareB += r.leaves * r.leaves;
                }
                if (r.dist == c) {
                    sumC += r.leaves;
                }
            }

            const long long score = static_cast<long long>(a) * (b + c);
            // b == c: both endpoints come from the pool at distance b, so count
            //         unordered pairs of leaves lying in different branches.
            // b >  c: one endpoint from the pool at b, the other from the pool
            //         at c (this also covers a == b > c, where either of the
            //         two longest branches may hold the endpoint).
            const long long pairs =
                (b == c) ? (sumB * sumB - squareB) / 2 : sumB * sumC;

            if (score > best) {
                best = score;
                ways = pairs;
            } else if (score == best) {
                ways += pairs;
            }
        }

        // Two largest distinct branch distances with pooled leaf counts; they
        // are enough to answer "best branch of v that avoids child u".
        Reach first{-1, 0};
        Reach second{-1, 0};
        for (const Reach& r : branches) {
            if (r.dist > first.dist) {
                second = first;
                first = r;
            } else if (r.dist == first.dist) {
                first.leaves += r.leaves;
            } else if (r.dist > second.dist) {
                second = r;
            } else if (r.dist == second.dist) {
                second.leaves += r.leaves;
            }
        }

        for (const int u : adj[v]) {
            if (u == parent[v]) {
                continue;
            }
            const int viaU = down[u].dist + 1;
            if (viaU != first.dist) {
                // u does not contribute to the longest pool: keep it intact.
                up[u] = Reach{first.dist + 1, first.leaves};
            } else if (down[u].leaves != first.leaves) {
                // u shares the longest pool with other branches: remove its part.
                up[u] = Reach{first.dist + 1, first.leaves - down[u].leaves};
            } else {
                // u alone forms the longest pool: fall back to the runner-up,
                // which always exists because up[v] is among the branches.
                up[u] = Reach{second.dist + 1, second.leaves};
            }
        }
    }

    return {best, ways};
}

}  // namespace

// Reads n and the n-1 edges (1-based endpoints) from standard input and prints
// "<best score> <number of routes>" followed by a newline.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    if (!(std::cin >> n) || n < 1) {
        return 0;  // nothing sensible to compute on missing or empty input
    }

    std::vector<std::vector<int>> adj(n);
    for (int i = 0; i + 1 < n; ++i) {
        int a = 0;
        int b = 0;
        std::cin >> a >> b;
        --a;
        --b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    const auto [best, ways] = hardestRoutes(adj);
    std::cout << best << " " << ways << "\n";
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...