Submission #1157342

#TimeUsernameProblemLanguageResultExecution timeMemory
1157342nh0902Hard route (IZhO17_road)C++20
0 / 100
1 ms2632 KiB
#include <bits/stdc++.h>
using namespace std;

#define ll long long

const int N = 1e5 + 10;

int n, k;

vector<int> g[N];

int pa[N];

long long d_pa[N];

pair<int, long long> md[N][3];

long long ans;

int root;

void dfs(int u, int p) {
    pa[u] = p;

    for (int i = 0; i < 3; i ++) md[u][i] = {u, 0};

    for (int v : g[u]) if (v != p) {
        dfs(v, u);

        //cout << u << " - " << v << " : " << md[v][0].second << "\n";

        int id = 3;

        for (int i = 0; i < 3; i ++) {
            if (md[u][i].second < md[v][0].second + 1) {
                id = i;
                break;
            }
        }

        //cout << u << " - " << v << " : " << id << " , " << md[v][0].second + 1 << "\n";

        if (id < 3) {
            for (int i = 2; i > id; i --) {
                md[u][i] = md[u][i - 1];
            }
            md[u][id] = {v, md[v][0].second + 1};
            //cout << u << " " << id << " : " << v << " " << md[u][id].second << "\n";
        }
    }

    /*
    cout << u << "\n";
    for (int j = 0; j < 3; j ++) {
        cout << md[u][j].first << " " << md[u][j].second << "\n";
    }
    cout << "\n";
    */
}

void dfs2(int u, int p) {
    for (int v : g[u]) if (v != p) {
        if (v == md[u][0].first) {
            d_pa[v] = max(d_pa[u] + 1, md[u][1].second + 1);
        } else {
            d_pa[v] = max(d_pa[u] + 1, md[u][0].second + 1);
        }

        dfs2(v, u);
    }
}

void solve() {
    dfs(1, 1);
    dfs2(1, 1);

    for (int i = 1; i <= n; i ++) {

        //cout << i << "\n";
        for (int j = 0; j < 3; j ++) {
            //cout << md[i][j].first << " " << md[i][j].second << "\n";
        }
        //cout << "\n";

        long long cur = 0;

        if (md[i][1].second == 0 || (md[i][2].second == 0 && d_pa[i] == 0)) continue;
        cur = max(cur, md[i][0].second * (md[i][1].second + md[i][2].second));
        cur = max(cur, md[i][0].second * (md[i][1].second + d_pa[i]));
        cur = max(cur, d_pa[i] * (md[i][0].second + md[i][1].second));

        if (cur >= ans) {
            root = i;
            ans = cur;
        }
    }
}

int sz[N];

long long cnt[N];

long long total;

void predfs(int u, int p) {
    sz[u] = 1;
    for (int v : g[u]) if (v != p) {
        predfs(v, u);
        sz[u] += sz[v];
    }
    //cout << u << " : " << sz[u] << "\n";
}

void update(int u, int p, int cur_d, int val) {
    if (sz[u] == 1) cnt[cur_d] += val;

    for (int v : g[u]) if (v != p) {
        update(v, u, cur_d + 1, val);
    }
}

void dfs3(int u, int p, int cur_d) {

    if (sz[u] == 1) {
        cnt[cur_d] ++;
        return;
    }

    int b = 0;
    for (int v : g[u]) if (v != p) {
        if (b == 0 || sz[v] > sz[b]) b = v;
    }

    for (int v : g[u]) if (v != p && v != b) {
        dfs3(v, u, cur_d + 1);
        update(v, u, cur_d + 1, - 1);
    }

    dfs3(b, u, cur_d + 1);

    for (int v : g[u]) if (v != p && v != b) {
        update(v, u, cur_d + 1, 1);
    }

    /*
    cout << u << "\n";
    for (int i = 0; i <= n; i ++) {
        cout << i << " : " << cnt[i] << "\n";
    }
    cout << "\n";
    */

    long long x[3];
    for (int i = 0; i < 3; i ++) x[i] = md[u][i].second;

    if (x[0] * (x[1] + x[2]) < ans && d_pa[u] * (x[0] + x[1]) < ans) return;
    else if (d_pa[u] * (x[0] + x[1]) < ans) {
        if (x[1] == x[0]) {
            if (x[2] == x[1]) {
                total += (cnt[x[1] + cur_d] - 1) * (cnt[x[1] + cur_d] - 2) / 2;
            } else {
                total += (cnt[x[1] + cur_d] - 1) * cnt[x[2] + cur_d];
            }
        } else {
            if (x[2] == x[1]) {
                total += (cnt[x[1] + cur_d] - 1) * cnt[x[1] + cur_d] / 2;
            } else {
                total += cnt[x[1] + cur_d] * cnt[x[2] + cur_d];
            }
        }

    } else {
        if (x[1] == x[0]) {
            total += cnt[x[0] + cur_d] * (cnt[x[0] + cur_d] - 1) / 2;
        } else {
            total += cnt[x[0] + cur_d] * cnt[x[1] + cur_d];
        }
    }

}

int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    cin >> n;

    int u, v;

    for (int i = 1; i < n; i ++) {
        cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }

    root = 1;
    solve();
    cout << ans << "\n";
    //cout << root << "\n";
    if (ans == 0) {
        cout << 1 << "\n";
        return 0;
    }
    predfs(root, root);
    dfs3(root, root, 0);
    cout << total << "\n";
}




#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...