Submission #704158

#TimeUsernameProblemLanguageResultExecution timeMemory
704158TimDeePaths (RMI21_paths)C++17
56 / 100
201 ms24492 KiB
#include <bits/stdc++.h> using namespace std; #pragma GCC optimize("O3") #pragma GCC optimize("Ofast") #pragma GCC optimize("O2") using ll = long long; #define int long long #define forn(i,n) for(int i=0; i<(n); ++i) #define pb push_back #define pi pair<int,int> #define f first #define s second #define vii(a,n) vector<int> a(n); forn(i,n) cin>>a[i]; #define all(x) x.begin(), x.end() #define rall(x) x.rbegin(), x.rend() mt19937 rng(chrono::steady_clock::now().time_since_epoch().count()); const int inf = 1e15; const int mod = 998244353; const int n = 1e5; vector<int> qans(n,0); multiset<int> maxs, mins; int sum = 0; vector<pi> in(n,{0,0}),out(n,{0,0}); vector<vector<pi>> adj(n); void dfs(int u, int p) { in[u]={-1,u}; for (auto&e:adj[u]) { int v=e.f, w=e.s; if (v==p) continue; dfs(v,u); in[u]=max(in[u],{in[v].f+w,in[v].s}); } if (in[u].f==-1) in[u].f=0; } void dfs2(int u, int p, pi o) { if (u) out[u]=o; vector<pi> pr={{-1,-1}}; pi sf={-1,-1}; for (auto&e:adj[u]) { int v=e.f, w=e.s; if (v==p) { pr.pb(pr.back()); continue; } pr.pb(max(pr.back(),{in[v].f+w,in[v].s})); } int sz=adj[u].size(); for (int i=sz-1; i>=0; --i) { auto e=adj[u][i]; int v=e.f, w=e.s; if (v==p) continue; dfs2(v,u,max({ make_pair(o.f+w,o.s) , make_pair(sf.f+w,sf.s) , make_pair(pr[i].f+w,pr[i].s) })); sf=max(sf,{in[v].f+w,in[v].s}); } } vector<int> cnt(n,0); vector<vector<pi>> ans(n); vector<pi> mx(n); void cntdfs(int u,int p) { pi m={-1,-1}; for(auto&e:adj[u]) { int v=e.f,w=e.s; if (v==p) continue; cntdfs(v,u); if (ans[u].size()<ans[v].size()) swap(ans[u],ans[v]); if (mx[v].f+w>m.f) { if (m.s!=-1) ans[u].pb(m); m={mx[v].f+w,mx[v].s}; } else { ans[u].pb({mx[v].f+w,mx[v].s}); } for(auto&x:ans[v]) ans[u].pb(x); } if (m.s==-1) mx[u]={0,u}; else mx[u]=m; } int N; auto it = maxs.find(0); void reroot(int u, int p, int W) { bool Z=0,V=0,ZZ=0,VV=0,S=0,SS=0; int X=0,XX=0; if (u && N<=2000) { if (maxs.count(cnt[out[u].s])) { Z=1; it=maxs.find(cnt[out[u].s]); maxs.erase(it); cnt[out[u].s]+=W; maxs.insert(cnt[out[u].s]); sum+=W; } else { it=mins.find(-cnt[out[u].s]); mins.erase(it); cnt[out[u].s]+=W; if (cnt[out[u].s]>(*maxs.begin())) { V=1; int x=*maxs.begin(); X=x; maxs.erase(maxs.begin()); maxs.insert(cnt[out[u].s]); sum+=cnt[out[u].s]-x; mins.insert(-x); } else { mins.insert(-cnt[out[u].s]); } } if (maxs.count(cnt[in[u].s])) { ZZ=1; it=maxs.find(cnt[in[u].s]); maxs.erase(it); sum-=cnt[in[u].s]; cnt[in[u].s]-=W; if (cnt[in[u].s] >= -(*mins.begin())) { VV=1; maxs.insert(cnt[in[u].s]); sum+=cnt[in[u].s]; } else if (mins.size()) { S=1; int x=*mins.begin(); mins.erase(mins.begin()); XX=-x; sum+=-x; maxs.insert(-x); mins.insert(-cnt[in[u].s]); } } else { it=mins.find(-cnt[in[u].s]); mins.erase(it); cnt[in[u].s]-=W; mins.insert(-cnt[in[u].s]); } } qans[u]=sum; for (auto&e:adj[u]) { int v=e.f, w=e.s; if (v==p) continue; reroot(v,u,w); } if (u && N<=2000) { if (Z) { it=maxs.find(cnt[out[u].s]); maxs.erase(it); cnt[out[u].s]-=W; maxs.insert(cnt[out[u].s]); sum-=W; } else { if (V) { it=mins.find(-X); mins.erase(it); it=maxs.find(cnt[out[u].s]); maxs.erase(it); maxs.insert(X); sum-=cnt[out[u].s]-X; } else { it=mins.find(-cnt[out[u].s]); mins.erase(it); } cnt[out[u].s]-=W; mins.insert(-cnt[out[u].s]); } if (ZZ) { if (VV) { sum-=cnt[in[u].s]; it=maxs.find(cnt[in[u].s]); maxs.erase(it); } else if (S) { it=mins.find(-cnt[in[u].s]); mins.erase(it); it=maxs.find(XX); maxs.erase(it); sum-=XX; mins.insert(-XX); } cnt[in[u].s]+=W; sum+=cnt[in[u].s]; maxs.insert(cnt[in[u].s]); } else { it=mins.find(-cnt[in[u].s]); mins.erase(it); cnt[in[u].s]+=W; mins.insert(-cnt[in[u].s]); } } } void solve() { int n,k; cin>>n>>k; N=n; forn(i,n-1) { int u,v,w; cin>>u>>v>>w; --u, --v; adj[u].pb({v,w}); adj[v].pb({u,w}); } dfs(0,0); dfs2(0,0,{0,0}); cntdfs(0,0); ans[0].pb(mx[0]); if (adj[0].size()==1) ans[0].pb({0,0}); sort(rall(ans[0])); for (int i=0; i<ans[0].size(); ++i) { cnt[ans[0][i].s]=ans[0][i].f; if (i<k) { maxs.insert(ans[0][i].f); sum+=ans[0][i].f; } else mins.insert(-ans[0][i].f); } if (k>=ans[0].size()) { forn(i,n) cout<<sum<<'\n'; return; } reroot(0,0,0); forn(i,n) cout<<qans[i]<<'\n'; } int32_t main() { int t=1; //cin>>t; while (t--) solve(); return 0; }

Compilation message (stderr)

Main.cpp: In function 'void reroot(long long int, long long int, long long int)':
Main.cpp:91:29: warning: unused variable 'SS' [-Wunused-variable]
   91 |  bool Z=0,V=0,ZZ=0,VV=0,S=0,SS=0;
      |                             ^~
Main.cpp: In function 'void solve()':
Main.cpp:216:17: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<std::pair<long long int, long long int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  216 |  for (int i=0; i<ans[0].size(); ++i) {
      |                ~^~~~~~~~~~~~~~
Main.cpp:224:7: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<std::pair<long long int, long long int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  224 |  if (k>=ans[0].size()) {
      |      ~^~~~~~~~~~~~~~~
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...