Submission #1281687

| # | Time | Username | Problem | Language | Result | Execution time | Memory |
| 1281687 | | mwen | Hard route (IZhO17_road) | C++20 | 100 / 100 | 569 ms | 234376 KiB |
#include <bits/stdc++.h>

using namespace std;
// Contest shorthands.
using ll = long long;                         // signed integer of at least 64 bits
#define sz(x) static_cast<int>((x).size())    // container size as a signed int
#define all(x) begin(x), end(x)               // full iterator range (not referenced by main())

// IZhO 2017 "Hard route" (IZhO17_road).
//
// Input:  n, then n-1 edges "a b" (1-indexed) of a tree.
// Output: "best numWays".
//
// Every vertex of degree >= 3 is tried as a pivot. For each branch incident to
// the pivot we know the distance to the farthest leaf in that branch and how
// many leaves sit at that distance. With x1 >= x2 >= x3 the three largest
// branch distances (counted with multiplicity), the pivot scores x1 * (x2 + x3):
// a leaf-to-leaf route through the x2 and x3 branches, weighted by the distance
// x1 into another branch. numWays sums, over all pivots reaching the maximum,
// the number of leaf pairs realising it. If no vertex has degree >= 3 the
// output is "0 1".
// NOTE(review): presumably the score is "route length times farthest
// off-route distance" from the problem statement -- confirm there.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n;
    // With no vertices (or unreadable input) there is no vertex 0 to root the
    // tree at; indexing con[0] below would be out of bounds.
    if(!(cin >> n) || n <= 0)
        return 0;
    vector<vector<int>> con(n);
    for(int i = 0; i < n-1; i++) {
        int a, b;
        cin >> a >> b;
        a--; b--;  // to 0-indexed
        con[a].push_back(b);
        con[b].push_back(a);
    }

    // Pass 1 (tree rooted at vertex 0): d[v] = {distance from v down to its
    // deepest descendant without children, number of such descendants at that
    // distance}. A vertex with no children keeps the initial {0, 1}.
    vector<pair<int, int>> d(n, {0, 1});
    auto dfs = [&](auto&& self, int curr, int prev) -> void {
        for(int next : con[curr]) {
            if(next == prev)
                continue;
            self(self, next, curr);
            if(d[next].first+1 > d[curr].first) {
                // Strictly deeper branch replaces the current best.
                d[curr] = d[next];
                d[curr].first++;
            }
            else if(d[next].first+1 == d[curr].first) {
                // Tie: leaves of both branches count.
                d[curr].second += d[next].second;
            }
        }
    };
    dfs(dfs, 0, -1);

    // Best score so far and the number of ways it is reached. The initial
    // {0, 1} is what gets printed when no vertex has degree >= 3.
    ll best = 0, numWays = 1;
    auto updateBest = [&](ll v, ll ways) {
        if(v > best) {
            best = v;
            numWays = ways;
        }
        else if(v == best) {
            numWays += ways;
        }
    };

    // Pass 2 (rerooting): par = {distance, count} for the branch of curr that
    // goes through prev, i.e. the farthest leaves reachable via the parent.
    // For the root, the pseudo-branch {0, 1} stands for the root itself.
    auto reroot = [&](auto&& self, int curr, int prev, pair<ll, ll> par) -> void {
        // The three largest distinct branch distances, sorted descending.
        // For slot i:
        //   largest3[i]    the distance (-1 = empty slot)
        //   distCnts[i]    leaves at that distance, summed over branches
        //   subtreeCnts[i] number of branches having that distance
        //   overCount[i]   sum over those branches of C(cnt, 2), i.e. leaf
        //                  pairs lying inside one branch (excluded later)
        // std::array rather than std::vector: these stay alive across the
        // recursive calls below (depth = tree depth), and vectors would hold
        // four heap allocations per frame on the current root-to-node path.
        array<ll, 3> largest3{-1, -1, -1};
        array<ll, 3> distCnts{};
        array<ll, 3> subtreeCnts{};
        array<ll, 3> overCount{};
        // Record one branch whose farthest leaves are at `dist`, `cnt` of them.
        auto update = [&](ll dist, ll cnt) {
            for(int i = 0; i < 3; i++) {
                if(dist > largest3[i]) {
                    // New distance: shift slots i.. down by one, dropping the last.
                    for(int j = sz(largest3)-1; j > i; j--) {
                        largest3[j] = largest3[j-1];
                        distCnts[j] = distCnts[j-1];
                        subtreeCnts[j] = subtreeCnts[j-1];
                        overCount[j] = overCount[j-1];
                    }
                    largest3[i] = dist;
                    distCnts[i] = cnt;
                    subtreeCnts[i] = 1;
                    overCount[i] = cnt*(cnt-1)/2;
                    break;
                }
                else if(dist == largest3[i]) {
                    distCnts[i] += cnt;
                    subtreeCnts[i]++;
                    overCount[i] += cnt*(cnt-1)/2;
                    break;
                }
            }
        };
        auto [parDist, parCnt] = par;
        assert(parCnt > 0);
        update(parDist, parCnt);
        for(int next : con[curr]) {
            if(next == prev)
                continue;
            auto [nextDist, nextCnt] = d[next];
            update(nextDist+1, nextCnt);
        }
        // Score curr as a pivot (needs >= 3 branches). With A > B > C the slot
        // distances, the multiplier is always A and the route ends lie in the
        // next two largest branches (counted with multiplicity):
        //   #branches at A, B, C   score      ways
        //   >=3                    (A+A)*A    A-leaf pairs in different branches
        //   2, >=1                 (A+B)*A    aCnt * bCnt
        //   1, >=2                 (B+B)*A    B-leaf pairs in different branches
        //   1, 1, >=1              (B+C)*A    bCnt * cCnt
        if(sz(con[curr]) >= 3) {
            ll A = largest3[0];
            ll B = largest3[1];
            ll C = largest3[2];
            ll aTrees = subtreeCnts[0];
            ll bTrees = subtreeCnts[1];
            ll cTrees = subtreeCnts[2];
            ll aCnt = distCnts[0];
            ll bCnt = distCnts[1];
            ll cCnt = distCnts[2];
            ll aOver = overCount[0];
            ll bOver = overCount[1];
            ll v, ways;
            if(aTrees >= 3) {
                v = (A+A)*A;
                ways = aCnt*(aCnt-1)/2-aOver;
            }
            else if(aTrees == 2) {
                assert(bTrees >= 1);
                v = (A+B)*A;
                ways = aCnt*bCnt;
            }
            else if(aTrees == 1 && bTrees >= 2) {
                v = (B+B)*A;
                ways = bCnt*(bCnt-1)/2-bOver;
            }
            else {
                assert(aTrees == 1 && bTrees == 1 && cTrees >= 1);
                v = (B+C)*A;
                ways = bCnt*cCnt;
            }
            updateBest(v, ways);
        }

        // Recurse. A child's upward branch is curr's best branch other than the
        // child itself, one edge longer.
        for(int next : con[curr]) {
            if(next == prev)
                continue;
            auto [nextDist, nextCnt] = d[next];
            if(largest3[0] == nextDist+1) {
                if(distCnts[0] == nextCnt) {
                    // next is the only branch at the top distance: use slot 1.
                    self(self, next, curr, {largest3[1]+1, distCnts[1]});
                }
                else {
                    // Other branches tie at the top distance: drop next's leaves.
                    self(self, next, curr, {largest3[0]+1, distCnts[0]-nextCnt});
                }
            }
            else {
                self(self, next, curr, {largest3[0]+1, distCnts[0]});
            }
        }
    };
    reroot(reroot, 0, -1, {0, 1});
    cout << best << " " << numWays << "\n";
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...