제출 #1342009

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1342009 | NValchanov | Hard route (IZhO17_road) | C++20
100 / 100
804 ms | 129012 KiB
#include <iostream>
#include <vector>
#include <algorithm>
#include <cassert>

using namespace std;

// 64-bit integer used for scores and counts that can exceed int range.
using llong = long long;

// Array capacity: the vertex limit plus a small safety margin (vertices are 1-indexed).
constexpr int MAXN = 500'010;

// One branch hanging off a vertex: `len` is the length of the longest path
// into that branch and `cnt` is how many endpoints reach exactly that length.
struct Path
{
    int len = 0;
    int cnt = 0;

    // Default-constructed paths are well-defined (0, 0) instead of holding garbage.
    Path() = default;

    Path(int len, int cnt) : len(len), cnt(cnt) {}

    // NOTE: intentionally reversed. `p1 < p2` means "p1 should come first", so
    // std::sort arranges paths by DESCENDING len, ties broken by DESCENDING cnt.
    inline friend bool operator<(const Path& p1, const Path& p2)
    {
        if(p1.len != p2.len)
            return p1.len > p2.len;

        return p1.cnt > p2.cnt;
    }
};

int n;
vector < int > adj[MAXN];
int farth[MAXN][2];
int cnt[MAXN][2];
pair < int, int > opt[MAXN][2];
pair < llong, llong > dp[MAXN];

// Reads n followed by the n - 1 undirected edges of the tree into adj.
void read()
{
    std::cin >> n;

    int remaining = n - 1;
    while(remaining > 0)
    {
        int a, b;
        std::cin >> a >> b;

        adj[a].push_back(b);
        adj[b].push_back(a);

        --remaining;
    }
}

// Bottom-up pass over the subtree of u, rooted so that p is u's parent
// (solve() passes p == u for the root). For every vertex x in that subtree it sets
//   farth[x][0] = length of the longest downward path starting at x,
//   cnt[x][0]   = number of vertices of x's subtree at exactly that distance
//                 from x (x itself, count 1, when x is a leaf).
// Uses an explicit stack instead of recursion: a path-shaped tree with up to MAXN
// vertices would otherwise need that many nested calls and can overflow the stack.
void dfs1(int u, int p)
{
    // Preorder list of (vertex, parent); every parent precedes its children.
    std::vector<std::pair<int, int>> order;
    std::vector<std::pair<int, int>> pending;
    pending.push_back({u, p});

    while(!pending.empty())
    {
        const std::pair<int, int> cur = pending.back();
        pending.pop_back();
        order.push_back(cur);

        for(int v : adj[cur.first])
        {
            if(v != cur.second)
                pending.push_back({v, cur.first});
        }
    }

    // Reverse preorder: every child is finalized before its parent reads it.
    // The max/count fold below does not depend on the order children are seen.
    for(auto it = order.rbegin(); it != order.rend(); ++it)
    {
        const int x = it->first;
        const int px = it->second;

        farth[x][0] = 0;
        cnt[x][0] = 1;

        for(int v : adj[x])
        {
            if(v == px)
                continue;

            if(farth[x][0] < farth[v][0] + 1)
            {
                farth[x][0] = farth[v][0] + 1;
                cnt[x][0] = cnt[v][0];
            }
            else if(farth[x][0] == farth[v][0] + 1)
            {
                cnt[x][0] += cnt[v][0];
            }
        }
    }
}

void dfs2(int u, int p)
{
    vector < Path > paths;

    for(int& v : adj[u])
    {
        if(v == p)
            continue;

        paths.push_back(Path(farth[v][0] + 1, cnt[v][0]));
    }

    paths.push_back(Path(farth[u][1], cnt[u][1]));

    sort(paths.begin(), paths.end());

    if(adj[u].size() >= 3)
    {
        Path p1 = paths[0];
        Path p2 = paths[1];
        Path p3 = paths[2];

        dp[u].first = 1LL * p1.len * (p2.len + p3.len);

        llong sum = 0;
        llong sum2 = 0;

        for(Path& pt : paths)
        {
            if(pt.len == p3.len)
            {
                sum += pt.cnt;
                sum2 += (1LL * pt.cnt * pt.cnt);
            }
        }

        if(p1.len != p2.len && p2.len != p3.len)
        {
            dp[u].second = 1LL * p2.cnt * sum; 
        }
        else if(p1.len == p2.len && p2.len == p3.len)
        {
            dp[u].second = (1LL * sum * sum - sum2) / 2LL; 
        }
        else if(p1.len == p2.len)
        {
            dp[u].second = 1LL * p1.cnt * sum + 1LL * p2.cnt * sum;
        }
        else if(p2.len == p3.len)
        {
            dp[u].second = (1LL * sum * sum - sum2) / 2LL;
        }
        else
        {
            assert(false);
        }
    }

    opt[u][0] = opt[u][1] = {0, 1};

    for(int& v : adj[u])
    {
        if(v == p)
            continue;

        if(farth[v][0] + 2 > opt[u][0].first)
        {
            opt[u][1] = opt[u][0];
            opt[u][0] = {farth[v][0] + 2, cnt[v][0]};
        }
        else if(farth[v][0] + 2 == opt[u][0].first)
        {
            opt[u][0].second += cnt[v][0];
        }
        else if(farth[v][0] + 2 > opt[u][1].first)
        {
            opt[u][1] = {farth[v][0] + 2, cnt[v][0]};
        }
        else if(farth[v][0] + 2 == opt[u][1].first)
        {
            opt[u][1].second += cnt[v][0];
        }
    }

    if(farth[u][1] + 1 > opt[u][0].first)
    {
        opt[u][1] = opt[u][0];
        opt[u][0] = {farth[u][1] + 1, cnt[u][1]};
    }
    else if(farth[u][1] + 1 == opt[u][0].first)
    {
        opt[u][0].second += cnt[u][1];
    }
    else if(farth[u][1] + 1 > opt[u][1].first)
    {
        opt[u][1] = {farth[u][1] + 1, cnt[u][1]};
    }
    else if(farth[u][1] + 1 == opt[u][1].first)
    {
        opt[u][1].second += cnt[u][1];
    }

    for(int& v : adj[u])
    {
        if(v == p)
            continue;

        if(farth[v][0] + 2 == opt[u][0].first)
        {
            if(opt[u][0].second - cnt[v][0] == 0)
            {
                farth[v][1] = opt[u][1].first;
                cnt[v][1] = opt[u][1].second;
            }
            else
            {
                farth[v][1] = opt[u][0].first;
                cnt[v][1] = opt[u][0].second - cnt[v][0];
            }
        }
        else
        {
            farth[v][1] = opt[u][0].first;
            cnt[v][1] = opt[u][0].second;
        }

        dfs2(v, u);
    }
}

void solve()
{
    dfs1(1, 1);

    farth[1][1] = 0;
    cnt[1][1] = 1;
    dfs2(1, 1);

    pair < llong, llong > result = {0, 1};

    for(int i = 1; i <= n; i++)
    {
        if(dp[i].first > result.first)
        {
            result = dp[i];
        }
        else if(dp[i].first == result.first)
        {
            result.second += dp[i].second;
        }
    }

    cout << result.first << " " << result.second << endl;
}

int main()
{
    // read() extracts about 2 * n integers (n up to about MAXN). The program never
    // uses C stdio, so unsyncing iostreams and untying cin from cout is safe.
    // It speeds up input considerably.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    read();
    solve();

    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...