Submission #1311703

# | Time | Username | Problem | Language | Result | Execution time | Memory
1311703 | (not shown) | thelegendary08 | Road Closures (APIO21_roads) | C++20 | 100 / 100 | 362 ms | 128464 KiB
// Road Closures (APIO 2021, task "roads").
//
// minimum_closure_costs(N, U, V, W) receives a tree on N vertices whose i-th
// edge joins U[i] and V[i] with weight W[i]. For every k in [0, N-1] it
// returns the minimum total weight of edges that must be closed so that no
// vertex keeps more than k open edges.
//
// Approach: maximise the weight that is KEPT. For a fixed k a vertex is
// "light" when deg <= k and "heavy" otherwise.
//   * An edge whose endpoints are both light is always kept.
//   * An edge between a light vertex and a heavy vertex h only constrains h;
//     it becomes an optional candidate stored in h's EdgeTree.
//   * Edges between two heavy vertices form a forest that is solved with a
//     tree DP (ClosureSolver::dfs).
// A vertex v is heavy for exactly deg(v) values of k, so across all k the DP
// visits sum(deg) = 2(N-1) vertices in total.

#ifdef __has_include
#  if __has_include("roads.h")
#    include "roads.h"  // grader interface; optional so the file also builds standalone
#  endif
#endif

#include <algorithm>
#include <set>
#include <utility>
#include <vector>

namespace {

using ll = long long;

// Segment tree over the edges incident to one vertex. Leaf i is that vertex's
// i-th incident edge in descending (weight, neighbour) order, so the active
// leaves read left to right are sorted by weight, heaviest first.
class EdgeTree {
public:
    explicit EdgeTree(int size) : size_(size), sum_(4 * size + 5), cnt_(4 * size + 5) {}

    // Marks leaf `pos` as active with weight `weight`.
    void activate(int pos, ll weight) { update(1, 0, size_ - 1, pos, weight); }

    // Returns {sum of the `count` heaviest active weights,
    //          weight of the count-th heaviest active leaf}.
    // A non-positive `count` (or an empty tree) yields {0, 0}.
    std::pair<ll, ll> heaviest(int count) const {
        if (size_ == 0) return {0, 0};
        return query(1, 0, size_ - 1, count);
    }

    int activeCount() const { return cnt_[1]; }
    ll activeSum() const { return sum_[1]; }

private:
    void update(int node, int lo, int hi, int pos, ll weight) {
        if (lo == hi) {
            sum_[node] = weight;
            cnt_[node] = 1;
            return;
        }
        const int mid = (lo + hi) / 2;
        if (pos <= mid) {
            update(2 * node, lo, mid, pos, weight);
        } else {
            update(2 * node + 1, mid + 1, hi, pos, weight);
        }
        sum_[node] = sum_[2 * node] + sum_[2 * node + 1];
        cnt_[node] = cnt_[2 * node] + cnt_[2 * node + 1];
    }

    std::pair<ll, ll> query(int node, int lo, int hi, int count) const {
        if (count <= 0) return {0, 0};
        if (lo == hi) return {sum_[node], sum_[node]};
        const int mid = (lo + hi) / 2;
        const int leftCount = cnt_[2 * node];
        if (leftCount >= count) return query(2 * node, lo, mid, count);
        const std::pair<ll, ll> right = query(2 * node + 1, mid + 1, hi, count - leftCount);
        return {sum_[2 * node] + right.first, right.second};
    }

    int size_ = 0;
    std::vector<ll> sum_;   // total active weight in the node's range
    std::vector<int> cnt_;  // number of active leaves in the node's range
};

// Maximum total obtainable by choosing at most `slots` values among `gains`
// (sorted descending, possibly negative) and the active leaves of `tree`.
// `nonNegCount` / `nonNegSum` count and sum every active tree leaf plus every
// gain >= 0. Tree leaves are edge weights and are treated as non-negative
// (NOTE(review): relies on the task's weight constraints — confirm).
ll bestGain(const std::vector<ll>& gains, const EdgeTree& tree, int slots,
            int nonNegCount, ll nonNegSum) {
    if (slots <= 0) return 0;
    // Not enough useful candidates to fill every slot: take all of them and
    // never pad with a negative gain.
    if (nonNegCount < slots) return nonNegSum;

    // Otherwise the optimum is the `slots` largest values of the merged
    // sequence; binary-search how many of them come from `gains`.
    int lo = std::max(0, slots - tree.activeCount());
    int hi = std::min(slots, static_cast<int>(gains.size()));
    while (lo < hi) {
        const int mid = lo + (hi - lo + 1) / 2;
        // Does gains[mid-1] beat the (slots-mid+1)-th heaviest tree leaf?
        if (gains[mid - 1] >= tree.heaviest(slots - mid + 1).second) {
            lo = mid;
        } else {
            hi = mid - 1;
        }
    }
    ll total = tree.heaviest(slots - lo).first;
    for (int i = 0; i < lo; ++i) total += gains[i];
    return total;
}

// Tree DP over the forest of heavy vertices for the current bound `k`.
struct ClosureSolver {
    int k = 0;  // current degree bound
    // heavyAdj[v]: (neighbour, weight) of edges whose endpoints are both heavy.
    std::vector<std::set<std::pair<int, ll>>> heavyAdj;
    // lightEdges[v]: edges from v to neighbours of smaller degree that are
    // already light.
    std::vector<EdgeTree> lightEdges;
    // Best kept weight in v's DP subtree (edges between heavy descendants plus
    // their light-edge candidates), not counting the edge to v's DP parent.
    // keepWithParent: v may keep at most k-1 of those edges (the parent edge
    // is kept and uses a slot). keepWithoutParent: v may keep at most k.
    std::vector<ll> keepWithParent;
    std::vector<ll> keepWithoutParent;
    std::vector<char> visited;

    // Recursive; the depth equals the height of the heavy component.
    void dfs(int node, int parent) {
        visited[node] = 1;
        ll childrenBest = 0;    // sum of every heavy child's best value
        std::vector<ll> gains;  // extra weight from keeping each child edge
        for (const auto& [child, weight] : heavyAdj[node]) {
            if (child == parent) continue;
            dfs(child, node);
            const ll best = std::max(keepWithParent[child], keepWithoutParent[child]);
            childrenBest += best;
            gains.push_back(keepWithParent[child] + weight - best);
        }
        std::sort(gains.rbegin(), gains.rend());

        const EdgeTree& tree = lightEdges[node];
        int nonNegCount = tree.activeCount();
        ll nonNegSum = tree.activeSum();
        for (const ll g : gains) {
            if (g >= 0) {
                ++nonNegCount;
                nonNegSum += g;
            }
        }
        keepWithoutParent[node] = childrenBest + bestGain(gains, tree, k, nonNegCount, nonNegSum);
        keepWithParent[node] = childrenBest + bestGain(gains, tree, k - 1, nonNegCount, nonNegSum);
    }
};

}  // namespace

// Returns a vector of size N whose k-th entry is the minimum total weight of
// closed edges such that every vertex has at most k open edges.
std::vector<long long> minimum_closure_costs(int N, std::vector<int> U, std::vector<int> V,
                                             std::vector<int> W) {
    const int n = N;
    const int m = std::max(n - 1, 0);  // number of tree edges

    std::vector<int> deg(n, 0);
    ll totalWeight = 0;
    for (int e = 0; e < m; ++e) {
        ++deg[U[e]];
        ++deg[V[e]];
        totalWeight += W[e];
    }

    // Vertices become light at k = deg(v).
    std::vector<std::vector<int>> verticesByDeg(n);
    for (int v = 0; v < n; ++v) verticesByDeg[deg[v]].push_back(v);

    // An edge is kept for free once k reaches max(deg) (both ends light) and
    // leaves the heavy forest once k reaches min(deg) (one end light).
    std::vector<ll> freedAt(n, 0);
    std::vector<std::vector<int>> detachedAt(n);
    for (int e = 0; e < m; ++e) {
        const int du = deg[U[e]];
        const int dv = deg[V[e]];
        freedAt[std::max(du, dv)] += W[e];
        detachedAt[std::min(du, dv)].push_back(e);
    }

    // Leaf position of each edge inside each endpoint's EdgeTree:
    // rank by (weight, neighbour), descending.
    std::vector<std::vector<int>> incident(n);
    for (int e = 0; e < m; ++e) {
        incident[U[e]].push_back(e);
        incident[V[e]].push_back(e);
    }
    std::vector<int> rankAtU(m, 0);
    std::vector<int> rankAtV(m, 0);
    for (int v = 0; v < n; ++v) {
        std::vector<int>& edges = incident[v];
        const auto key = [&](int e) { return std::make_pair(W[e], U[e] == v ? V[e] : U[e]); };
        std::sort(edges.begin(), edges.end(), [&](int a, int b) { return key(a) > key(b); });
        for (int r = 0; r < static_cast<int>(edges.size()); ++r) {
            const int e = edges[r];
            if (U[e] == v) {
                rankAtU[e] = r;
            } else {
                rankAtV[e] = r;
            }
        }
    }

    ClosureSolver solver;
    solver.heavyAdj.resize(n);
    for (int e = 0; e < m; ++e) {
        solver.heavyAdj[U[e]].emplace(V[e], W[e]);
        solver.heavyAdj[V[e]].emplace(U[e], W[e]);
    }
    solver.lightEdges.reserve(n);
    for (int v = 0; v < n; ++v) solver.lightEdges.emplace_back(deg[v]);
    solver.keepWithParent.assign(n, 0);
    solver.keepWithoutParent.assign(n, 0);
    solver.visited.assign(n, 0);

    std::set<int> heavy;
    for (int v = 0; v < n; ++v) heavy.insert(heavy.end(), v);

    std::vector<long long> answer;
    answer.reserve(n);
    ll alwaysKept = 0;
    for (int k = 0; k < n; ++k) {
        solver.k = k;
        alwaysKept += freedAt[k];
        for (const int v : verticesByDeg[k]) heavy.erase(v);

        for (const int e : detachedAt[k]) {
            const int u = U[e];
            const int v = V[e];
            const ll w = W[e];
            // Only the strictly heavier endpoint still needs to decide about
            // this edge; with equal degrees both ends turn light together.
            if (deg[u] < deg[v]) {
                solver.lightEdges[v].activate(rankAtV[e], w);
            } else if (deg[v] < deg[u]) {
                solver.lightEdges[u].activate(rankAtU[e], w);
            }
            solver.heavyAdj[u].erase(std::make_pair(v, w));
            solver.heavyAdj[v].erase(std::make_pair(u, w));
        }

        for (const int v : heavy) solver.visited[v] = 0;
        ll cost = totalWeight - alwaysKept;
        for (const int v : heavy) {
            if (solver.visited[v]) continue;
            solver.dfs(v, -1);
            cost -= std::max(solver.keepWithParent[v], solver.keepWithoutParent[v]);
        }
        answer.push_back(cost);
    }
    return answer;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...