제출 #1346706

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1346706 | | bbbiros | Hard route (IZhO17_road) | C++20 | 0 / 100 | 2 ms | 344 KiB
#include <iostream>
#include <vector>
#include <cstring>
#include <string>
#include <iomanip>
#include <algorithm>
#include <fstream>
#include <cmath>
#include <unordered_set>
#include <set>
#include <unordered_map>
#include <map>
// Shorthand for the 64-bit integer type used for counts and products.
#define ll long long
// Short aliases for pair::first / pair::second; throughout this file they are
// applied to (distance, leaf count) pairs such as dp1/dp2 entries.
#define X first
#define Y second
// Makes `endl` a plain newline character (no stream flush).
// NOTE(review): #defining a name declared by a standard header (std::endl) is
// not permitted by the standard ([macro.names]); it only works here because it
// follows every #include.
#define endl '\n'
using namespace std;

// Stream I/O setup: stop synchronizing the C++ standard streams with C stdio,
// and untie cin from cout so reading no longer flushes cout first.
// (cout is untied too; it normally has no tie, so that call is only for parity.)
void speed()
{
    ios_base::sync_with_stdio(false);
    cout.tie(nullptr);
    cin.tie(nullptr);
}
// Capacity of the per-vertex arrays; vertex ids are used directly as indices,
// so every id read from input must stay below this.
constexpr int MAXN = 500010;
// Number of vertices (filled by read()).
int n;
// Adjacency lists of the tree; each edge is stored in both directions.
vector<int> v[MAXN];
// dp1[u]: (greatest distance from u down to a leaf of u's subtree,
//          number of leaves at that distance).
pair<int, ll> dp1[MAXN];
// dp2[u]: the same pair, but for leaves reached by first stepping from u to its parent.
pair<int, ll> dp2[MAXN];
// Folds candidate y into accumulator x; each pair is (distance, leaf count).
// A strictly larger distance replaces x, an equal distance adds its count,
// and a smaller distance leaves x untouched.
void mert(pair<int, ll> &x, pair<int, ll> y)
{
    if (y.X > x.X)
        x = y;
    else if (y.X == x.X)
        x.Y += y.Y;
}
// Reads n and then n-1 edges "a b", recording each edge in both adjacency lists.
void read()
{
    cin >> n;
    for (int e = 0; e + 1 < n; e++)
    {
        int a, b;
        cin >> a >> b;
        v[a].push_back(b);
        v[b].push_back(a);
    }
}
// Moves a (distance, leaf count) pair `val` edges further away:
// the distance grows by val and the count is unchanged.
pair<int, ll> operator+(pair<int, ll> x, int val)
{
    x.X += val;
    return x;
}
// Fills dp1[u] for every vertex u of the subtree rooted at `beg` (entered from `par`):
//   dp1[u] = (greatest distance from u down to a leaf of u's subtree,
//             number of leaves at that distance).
// A vertex without children counts itself as one leaf at distance 0.
// Iterative rather than recursive: on a path-shaped tree a recursive DFS nests
// about n calls deep (n may approach MAXN), which can exhaust the call stack.
void dfs1(int beg, int par)
{
    // Pre-order listing of (vertex, parent); a parent always precedes its children.
    vector<pair<int, int>> order;
    vector<pair<int, int>> todo;
    todo.push_back({beg, par});
    while (!todo.empty())
    {
        const pair<int, int> cur = todo.back();
        todo.pop_back();
        order.push_back(cur);
        for (int nb : v[cur.first])
            if (nb != cur.second)
                todo.push_back({nb, cur.first});
    }
    // Walking the pre-order backwards finishes every child before its parent.
    for (auto it = order.rbegin(); it != order.rend(); ++it)
    {
        const int u = it->first;
        const int p = it->second;
        bool has_child = false;
        for (int nb : v[u])
        {
            if (nb == p)
                continue;
            has_child = true;
            mert(dp1[u], dp1[nb] + 1);
        }
        // Only childless vertices are leaves of the rooted subtree.
        if (!has_child)
            mert(dp1[u], {0, 1});
    }
}
// parr[u]: parent of u in the rooted tree, recorded by dfs2
// (the root's entry is the `par` argument dfs2 was started with).
int parr[MAXN] = {};
void dfs2(int beg, int par)
{
    parr[beg] = par;
    vector<int> kids;
    for (int nb : v[beg])
    {
        if (nb == par)
            continue;
        kids.push_back(nb);
    }
    if (kids.empty())
        return;
    int sz = kids.size();
    vector<pair<int, ll>> pref(sz), sufx(sz);
    pref[0] = dp1[kids[0]] + 1;
    for (int i = 1; i < sz; i++)
    {
        pref[i] = pref[i - 1];
        mert(pref[i], dp1[kids[i]] + 1);
    }
    sufx.back() = dp1[kids.back()] + 1;
    for (int i = sz - 2; i >= 0; i--)
    {
        sufx[i] = sufx[i + 1];
        mert(sufx[i], dp1[kids[i]] + 1);
    }
    for (int i = 0; i < sz; i++)
    {
        pair<int, ll> val = dp2[beg];
        if (i - 1 >= 0)
            mert(val, pref[i - 1]);
        if (i + 1 < sz)
            mert(val, sufx[i + 1]);
        int nb = kids[i];
        dp2[nb] = val + 1;
        dfs2(nb, beg);
    }
}
vector<pair<int, ll>> ex(int x)
{
    vector<pair<int, ll>> ans;
    if (dp2[x].Y > 0)
        ans.push_back(dp2[x]);
    for (int nb : v[x])
    {
        if (nb == parr[x])
            continue;
        ans.push_back(dp1[nb] + 1);
    }
    sort(ans.begin(), ans.end(), [](auto &a, auto &b)
         { return a.X > b.X; });
    return ans;
}
signed main()
{
    speed();
    read();
    dp2[1] = {0, 0};
    dfs1(1, 0);
    dfs2(1, 0);
    ll ba = -1;
    for (int i = 1; i <= n; i++)
    {
        vector<pair<int, ll>> v = ex(i);
        if (v.size() < 3)
            continue;
        ba = max(ba, (ll)(v[0].X * (v[1].X + v[2].X)));
    }
    if (ba == -1)
    {
        cout << 0 << " " << 1 << endl;
        return 0;
    }
    ll cnt = 0;
    for (int i = 1; i <= n; i++)
    {
        vector<pair<int, ll>> v = ex(i);
        if (v.size() < 3)
            continue;
        if (v[0].X * (v[1].X + v[2].X) != ba)
            continue;
        int a = v[0].X;
        int b = v[1].X;
        int c = v[2].X;
        if (a == b && b == c)
        {
            ll sum1 = 0, sum2 = 0;
            ll ans = 0;
            for (int i = 0; i < v.size(); i++)
            {
                if (v[i].X != v[0].X)
                    break;
                sum1 += v[i].Y;
                sum2 += v[i].Y * v[i].Y;
            }
            ans = sum1 * sum1 - sum2;
            ans /= 2;
            cnt += ans;
        }
        else if (a != b && b == c)
        {
            ll sum1 = 0, sum2 = 0;
            ll ans = 0;
            for (int i = 1; i < v.size(); i++)
            {
                if (v[i].X != v[1].X)
                    break;
                sum1 += v[i].Y;
                sum2 += v[i].Y * v[i].Y;
            }
            ans = sum1 * sum1 - sum2;
            ans /= 2;
            cnt += ans;
        }
        else if (a == b && b != c)
        {
            ll sum = 0;
            for (int i = 2; i < v.size(); i++)
            {
                if (v[i].X != v[2].X)
                    break;
                sum += v[i].Y;
            }
            cnt += sum * (v[0].Y + v[1].Y);
        }
        else
        {
            ll sum = 0;
            for (int i = 2; i < v.size(); i++)
            {
                if (v[i].X != v[2].X)
                    break;
                sum += v[i].Y;
            }
            cnt += sum * (v[1].Y);
        }
    }
    cout << ba << ' ' << cnt << endl;
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...