제출 #1137837

#제출 시각아이디문제언어결과실행 시간메모리
1137837imarnHard route (IZhO17_road)C++20
0 / 100
5 ms12104 KiB
#include<bits/stdc++.h> #define ll long long #define pii pair<int,int> #define pll pair<ll,ll> #define plx pair<ll,int> #define f first #define s second #define pb push_back #define all(x) x.begin(),x.end() #define vi vector<int> #define vvi vector<vi> #define pp pair<ll,int> #define ub(x,i) upper_bound(all(x),i)-x.begin() #define lb(x,i) lower_bound(all(x),i)-x.begin() #define t3 tuple<int,int,int> using namespace std; const int mxn=5e5+5; vector<int>g[mxn]; struct node{ ll d,p,x; bool operator>(const node &o)const{ return d>o.d; } bool operator==(const node &o)const{ return d==o.d; } }; node dp[mxn][2]; node dp2[mxn]; void dfs(int u,int p){ dp[u][0]=dp[u][1]={0,-1,1}; for(auto v:g[u]){ if(v==p)continue; dfs(v,u); node tmp={dp[v][0].d+1,v,dp[v][0].x}; if(tmp.d>dp[u][0].d)swap(tmp,dp[u][0]); if(tmp.d>dp[u][1].d)swap(tmp,dp[u][1]); } } void dfs2(int u,int p){ dp2[u]=dp2[p];dp2[u].d++; node tmp=dp[p][0]; if(tmp.p==u)tmp=dp[p][1]; tmp.d++; if(tmp.d>dp2[u].d)swap(dp2[u],tmp); for(auto v:g[u]){ if(v==p)continue; dfs2(v,u); } } ll ans=-1; ll tt=0; pair<pll,ll> dfs3(int u,int p){ int mx=0;vector<pll>vec; vec.pb({0,1}); for(auto v:g[u]){ if(v==p)continue; pair<pll,ll> rs=dfs3(v,u); vec.pb(rs.f);dp2[u].d=max(dp2[u].d,rs.s); }sort(all(vec),greater<pii>()); if(vec.size()==3){ ll rs=(vec[0].f+vec[1].f)*dp2[u].d; if(rs>ans){ ans=rs; tt=vec[0].s*vec[1].s; }else if(rs==ans)tt+=vec[0].s*vec[1].s; } else if(vec.size()>3){ ll sm=0,sm2=0; for(int j=0;j<vec.size();j++)if(vec[j].f==vec[2].f)sm+=vec[j].s,sm2+=vec[j].s*vec[j].s; ll rs=(vec[2].f+vec[1].f)*vec[0].f; if(rs>ans)ans=rs,tt=0; if(vec[1].f==vec[2].f){ if(rs==ans)tt+=(sm*sm-sm2)/2; } else if(vec[0].f==vec[1].f&&vec[1].f>vec[2].f){ if(rs==ans)tt+=vec[0].s*sm+vec[1].s*sm; } else if(vec[0].f>vec[1].f&&vec[1].f>vec[2].f){ if(rs==ans)tt+=vec[1].s*sm; }sm=0,sm2=0; for(int j=0;j<vec.size();j++)if(vec[j].f==vec[1].f)sm+=vec[j].s,sm2+=vec[j].s*vec[j].s; rs=(vec[0].f+vec[1].f)*dp2[u].d; if(rs>ans)ans=rs,tt=0; if(vec[0].f>vec[1].f){ if(rs==ans)tt+=vec[0].s*sm; } if(vec[0].f==vec[1].f){ if(rs==ans)tt+=(sm*sm-sm2)/2; } }ll sm=0,mem=-1,sm2=0; for(int j=0;j<vec.size();j++){ if(vec[j].f==vec[0].f)sm+=vec[j].s; } if(vec.size()==1)return make_pair(make_pair(vec[0].f+1,sm),-1ll); else return make_pair(make_pair(vec[0].f+1,sm),vec[1].f+1); } int main(){ ios_base::sync_with_stdio(0);cin.tie(0); int n;cin>>n; for(int i=1;i<=n-1;i++){ int u,v;cin>>u>>v;g[u].pb(v);g[v].pb(u); }int rt=0;for(int i=1;i<=n;i++)if(g[i].size()>1)rt=i; dfs(rt,rt);dp2[rt]={0,0,1};for(auto v:g[rt])dfs2(v,rt); dfs3(rt,rt); cout<<ans<<' '<<tt; }
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...