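제출 #1131779 (Submission #1131779)

# | 제출 시각 (submitted at) | 아이디 (user ID) | 문제 (problem) | 언어 (language) | 결과 (result) | 실행 시간 (execution time) | 메모리 (memory)
1131779 | (not captured) | birsnot | Hard route (IZhO17_road) | C++20 | 52 / 100 | 804 ms | 320228 KiB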
// author: Nardos Wehabe

#include "bits/stdc++.h"


#ifndef ONLINE_JUDGE
#include "./debug/debug.h"
#endif

using namespace std;
using ll = long long;
// Ordered map whose iteration runs from the largest key down to the smallest.
using mapg = map<int, int, greater<>>;
// using lll = __int128;

// Offers n to a running top-two pair (mx1 >= mx2). A strictly larger n becomes
// the new maximum and the old maximum slides into mx2; otherwise n can only
// raise mx2. Equal values are kept, so ties may occupy both slots.
void chmx(int n, int& mx1, int& mx2) {
    if (n <= mx1) {
        if (n > mx2) mx2 = n;
        return;
    }
    mx2 = mx1;
    mx1 = n;
}

// Offers n to a running top-three (mx1 >= mx2 >= mx3). A strictly larger n
// pushes every held value down one slot; anything else competes for the
// (mx2, mx3) pair through chmx.
void chmx2(int n, int& mx1, int& mx2, int& mx3) {
    if (n <= mx1) {
        chmx(n, mx2, mx3);
        return;
    }
    mx3 = mx2;
    mx2 = mx1;
    mx1 = n;
}

// Copies the first three {key, count} entries of cnts (largest keys first,
// given mapg's descending order) into cur[0..2]. Slots past cnts.size() are
// not written, so callers must hand in a zeroed cur.
void convert3(mapg& cnts, array<array<int, 2>, 3>& cur) {
    auto it = cnts.begin();
    for (size_t slot = 0; slot < cur.size() && it != cnts.end(); ++slot, ++it) {
        cur[slot] = { it->first, it->second };
    }
}
// Copies the first six entries of cnts (largest keys first) into cur, as
// cur[key] = count; existing entries of cur with the same key are overwritten.
void convert6(mapg& cnts, mapg& cur) {
    int taken = 0;
    for (auto it = cnts.begin(); it != cnts.end() && taken < 6; ++it, ++taken) {
        cur[it->first] = it->second;
    }
}

// Moves a {depth, count} summary one edge further away: slot 0's depth always
// grows by one, and if it was empty (depth 0) it becomes a single entry at
// depth 1 with count 1. Slots 1 and 2 grow only when non-empty.
void incr(array<array<int, 2>, 3>& cur) {
    if (cur[0][0]++ == 0) ++cur[0][1];
    for (int slot = 1; slot < 3; ++slot) {
        if (cur[slot][0] != 0) ++cur[slot][0];
    }
}

void solve() {
    int N;
    cin >> N;
    vector<int> adj[N];
    for (int i = 0; i < N - 1; ++i) {
        int u, v;
        cin >> u >> v;
        u--, v--;
        adj[v].push_back(u);
        adj[u].push_back(v);
    }

    vector<array<array<int, 2>, 3>> mxs(N, { {0, 0} });

    function<array<array<int, 2>, 3>(int, int)> dfs1 = [&](int v, int p) {
        mapg cnts;
        for (auto ch : adj[v]) {
            if (ch == p) continue;
            auto ret = dfs1(ch, v);
            for (auto [n, fq] : ret) cnts[n] += fq;
        }
        convert3(cnts, mxs[v]);
        incr(mxs[v]);
        return mxs[v];
        };
    dfs1(0, -1);


    vector<array<ll, 2>> ans(N, { 0 });

    ll best = 0;

    function<void(int, int, array<array<int, 2>, 3>&)> dfs2 = [&](int v, int p, array<array<int, 2>, 3>& fr_p) {
        mapg cur;
        {
            mapg cnts;
            if (p != -1) mxs[p] = fr_p;

            int mx1 = 0, mx2 = 0, mx3 = 0;
            for (auto ch : adj[v]) {
                auto ret = mxs[ch];
                chmx2(ret[0][0], mx1, mx2, mx3);
                for (auto [n, fq] : ret) cnts[n] += fq;
            }
            convert6(cnts, cur);
            int cnt1 = cnts[mx1], cnt2 = cnts[mx2], cnt3 = cnts[mx3];
            ans[v][0] = (1ll * (mx2 + mx3) * mx1) * (mx3 > 0);
            best = max(ans[v][0], best);

            if (mx2 == mx3) {
                ans[v][1] = 1ll * cnt2 * (cnt2 - 1) / 2;

                for (auto ch : adj[v]) {
                    auto ret = mxs[ch];
                    int fq = 0;
                    for (auto [n, f] : ret) {
                        fq += f * (n == mx2);
                    }
                    if (ret[0][0] == mx1 && ret[0][1] == cnt1) {
                        ans[v][1] -= 1ll * fq * (cnt2 - fq);
                    }
                    ans[v][1] -= 1ll * fq * (fq - 1) / 2;
                }
            } else {
                ans[v][1] = 1ll * cnt2 * cnt3;
                for (auto ch : adj[v]) {
                    auto ret = mxs[ch];
                    int fq1 = 0, fq2 = 0;
                    for (auto [n, fq] : ret) {
                        fq1 += fq * (n == mx2);
                        fq2 += fq * (n == mx3);
                    }
                    if (ret[0][0] == mx1 && ret[0][1] == cnt1) {
                        ans[v][1] -= 1ll * fq1 * (cnt3 - fq2) + 1ll * fq2 * (cnt2 - fq1);
                    }
                    ans[v][1] -= 1ll * fq1 * fq2;
                }
            }
        }
        for (auto ch : adj[v]) {
            if (ch == p) continue;
            auto ret = mxs[ch];
            auto cur2 = cur;
            for (auto [n, fq] : ret) {
                cur2[n] -= fq;
                if (cur2[n] == 0) cur2.erase(n);
            }
            array<array<int, 2>, 3> to_ch = { {0, 0} };
            convert3(cur2, to_ch);
            incr(to_ch);
            dfs2(ch, v, to_ch);
        }
        };
    array<array<int, 2>, 3> ne = { {0, 0} };
    dfs2(0, -1, ne);

    if (best == 0) {
        cout << "0 1\n";
        return;
    }
    ll cnt = 0;
    for (auto [v, fq] : ans) {
        if (v == best) cnt += fq;
        assert(fq >= 0);
    }
    cout << best << " " << cnt << "\n";
}

int main() {
    // Unsynchronised, untied C++ streams for faster bulk I/O.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // One test case by default; uncomment the read for multi-test input.
    int tt = 1;
    // cin >> tt;
    for (; tt > 0; --tt) solve();
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...