Submission #1137835

#TimeUsernameProblemLanguageResultExecution timeMemory
1137835imarnHard route (IZhO17_road)C++20
0 / 100
5 ms12244 KiB
#include<bits/stdc++.h>
#define ll long long
#define pii pair<int,int>
#define pll pair<ll,ll>
#define plx pair<ll,int>
#define f first
#define s second
#define pb push_back
#define all(x) x.begin(),x.end()
#define vi vector<int>
#define vvi vector<vi>
#define pp pair<ll,int>
#define ub(x,i) upper_bound(all(x),i)-x.begin()
#define lb(x,i) lower_bound(all(x),i)-x.begin()
#define t3 tuple<int,int,int>
using namespace std;
const int mxn=5e5+5;
vector<int>g[mxn];
struct node{
    ll d,p,x;
    bool operator>(const node &o)const{
        return d>o.d;
    }
    bool operator==(const node &o)const{
        return d==o.d;
    }
};
node dp[mxn][2];
node dp2[mxn];
void dfs(int u,int p){
    dp[u][0]=dp[u][1]={0,-1,1};
    for(auto v:g[u]){
        if(v==p)continue;
        dfs(v,u);
        node tmp={dp[v][0].d+1,v,dp[v][0].x};
        if(tmp>dp[u][0])swap(tmp,dp[u][0]);
        if(tmp==dp[u][0])dp[u][0].x+=tmp.x,tmp={-1,0,0};
        if(tmp>dp[u][1])swap(tmp,dp[u][1]);
        if(tmp==dp[u][1])dp[u][1].x+=tmp.x,tmp={-1,0,0};
        if(tmp>dp[u][1])swap(tmp,dp[u][1]);
    }
}
void dfs2(int u,int p){
    dp2[u]=dp2[p];dp2[u].d++;
    node tmp=dp[p][0];
    if(tmp.d==dp[u][0].d+1){
        tmp.x-=dp[u][0].x;
        if(tmp.x==0)tmp=dp[p][1];
    }tmp.d++;
    if(dp2[u]==tmp)dp2[u].x+=tmp.x;
    if(tmp>dp2[u])swap(dp2[u],tmp);
    for(auto v:g[u]){
        if(v==p)continue;
        dfs2(v,u);
    }
}
ll ans=-1;
ll tt=0;
pll dfs3(int u,int p){
    int mx=0;vector<pll>vec;
    vec.pb({0,1});
    for(auto v:g[u]){
        if(v==p)continue;
        pll rs=dfs3(v,u);
        vec.pb(rs);
    }sort(all(vec),greater<pii>());
    if(vec.size()==3){
        ll rs=(vec[0].f+vec[1].f)*dp2[u].d;
        if(rs>ans){
            ans=rs;
            tt=vec[0].s*vec[1].s;
        }else if(rs==ans)tt+=vec[0].s*vec[1].s;
    }
    else if(vec.size()>3){
        ll sm=0,sm2=0;
        for(int j=0;j<vec.size();j++)if(vec[j].f==vec[2].f)sm+=vec[j].s,sm2+=vec[j].s*vec[j].s;
        ll rs=(vec[2].f+vec[1].f)*vec[0].f;
        if(rs>ans)ans=rs,tt=0;
        if(vec[0].f==vec[1].f&&vec[1].f==vec[2].f){
            if(rs==ans)tt+=(sm*sm-sm2)/2;
        }
        else if(vec[0].f>vec[1].f&&vec[1].f==vec[2].f){
            if(rs==ans)tt+=(sm*sm-sm2)/2;
        }
        else if(vec[0].f==vec[1].f&&vec[1].f>vec[2].f){
            if(rs==ans)tt+=vec[0].s*sm+vec[1].s*sm;
        }
        else if(vec[0].f>vec[1].f&&vec[1].f>vec[2].f){
            if(rs==ans)tt+=vec[1].s*sm;
        }sm=0,sm2=0;
        for(int j=0;j<vec.size();j++)if(vec[j].f==vec[1].f)sm+=vec[j].s,sm2+=vec[j].s*vec[j].s;
        rs=(vec[0].f+vec[1].f)*dp2[u].d;
        if(rs>ans)ans=rs,tt=0;
        if(vec[0].f>vec[1].f){
            if(rs==ans)tt+=vec[0].s*sm;
        }
        if(vec[0].f==vec[1].f){
            if(rs==ans)tt+=(sm*sm-sm2)/2;
        }
    }ll sm=0;
    for(int j=0;j<vec.size();j++)if(vec[j].f==vec[0].f)sm+=vec[j].s;
    return {vec[0].f+1,sm};
}
int main(){
    ios_base::sync_with_stdio(0);cin.tie(0);
    int n;cin>>n;
    for(int i=1;i<=n-1;i++){
        int u,v;cin>>u>>v;g[u].pb(v);g[v].pb(u);
    }int rt=0;for(int i=1;i<=n;i++)if(g[i].size()>1)rt=i;
    dfs(rt,rt);dp2[rt]={0,0,1};for(auto v:g[rt])dfs2(v,rt);
    dfs3(rt,rt);
    cout<<ans<<' '<<tt;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...