Submission #1311698

# | Time | Username | Problem | Language | Result | Execution time | Memory
1311698 | | thelegendary08 | Road Closures (APIO21_roads) | C++20 | 5 / 100 | 348 ms | 109800 KiB
#include "roads.h"
#include<bits/stdc++.h>
// NOTE: from here on every `int` token in this file is rewritten to `long long`
// by the preprocessor; the grader-facing signature therefore uses `signed`.
#define int long long
// Shorthands for std::make_pair / emplace_back / push_back.
#define mp make_pair
#define eb emplace_back
#define pb push_back
// f0r(i,n): loop i over [0, n);  FOR(i,k,n): loop i over [k, n).
#define f0r(i,n) for(int i = 0; i < n; i++)
#define FOR(i,k,n) for(int i = k; i < n; i++)
// vi = vector<long long> (because of the `int` macro above), vvi = vector of those.
#define vi vector<int>
#define vvi vector<vi>
// Debug printers: dout prints a value followed by its expression text,
// dout2 does the same for two expressions, vout prints every element of a range.
#define dout(x) cout<<x<<' '<<#x<<endl;
#define dout2(x,y) cout<<x<<' '<<#x<<' '<<y<<' '<<#y<<endl;
#define vout(v) for(auto u : v)cout<<u<<' '; cout<<endl;
using namespace std;
// Capacity of the per-vertex global arrays declared below (1e5 plus a small margin).
const int mxn = 1e5 + 5; 
// One undirected weighted edge: endpoints `u` and `v`, weight `w`.
// Kept as a plain aggregate so it can be brace-initialised as {u, v, w}.
struct Edge {
	int u;
	int v;
	int w;
};
// Segment tree over a fixed number of slots (indices 0..n-1).
// A slot is "inactive" (contributes nothing) until upd() assigns it a value.
// Each internal node stores, for its range: the sum of active values (`sum`),
// the number of active slots (`cnt`) and the minimum active value (`mn`,
// maintained but not read by any query in this struct).
//
// Types are spelled `long long` explicitly so the struct does not depend on
// the file-level `int` macro; inside this file the resulting types are the same.
struct segtree{
	long long n = 0;
	std::vector<long long> sum, cnt, mn;

	segtree(){}

	// Builds a tree with `x` slots, all inactive.
	segtree(long long x){
		n=x;
		sum.resize(4*n+5);
		cnt.resize(4*n+5);
		mn.resize(4*n+5,4000000000000000000LL);
	}

	// Recomputes node v from its two children.
	void pull(long long v){
		sum[v]=sum[v*2]+sum[v*2+1];
		cnt[v]=cnt[v*2]+cnt[v*2+1];
		mn[v]=std::min(mn[v*2],mn[v*2+1]);
	}

	// Activates slot k (within node v covering [tl, tr]) with value x.
	void update(long long v, long long tl, long long tr, long long k, long long x){
		if(tl==tr){
			sum[v]=mn[v]=x;
			cnt[v]=1;
			return;
		}
		long long tm = (tl+tr)>>1;
		if(k <= tm) update(v*2,tl,tm,k,x);
		else update(v*2+1,tm+1,tr,k,x);
		pull(v);
	}

	// Returns (sum of the first k active slots in index order, value of the k-th).
	// For k <= 0 the prefix is empty and (0, 0) is returned; previously the
	// descent always went left and reported the value stored in slot 0.
	// For k larger than the number of active slots the descent ends at the
	// last slot: .first is then the total of all active values and .second
	// should not be relied upon.
	std::pair<long long,long long> quer(long long v, long long tl, long long tr, long long k){
		if(k <= 0) return std::make_pair(0LL, 0LL);
		if(tl==tr) return std::make_pair(sum[v], sum[v]);
		long long tm = (tl+tr)>>1;
		if(cnt[v*2] >= k) return quer(v*2,tl,tm,k);
		// Going right keeps k >= 1 because cnt[v*2] < k here.
		std::pair<long long,long long> ret = quer(v*2+1,tm+1,tr,k-cnt[v*2]);
		return std::make_pair(ret.first+sum[v*2], ret.second);
	}

	// Activates slot k with value x. Indices outside [0, n) are ignored
	// (they would otherwise recurse outside the allocated nodes).
	void upd(long long k, long long x){
		if(k < 0 || k >= n) return;
		update(1,0,n-1,k,x);
	}

	// Public form of quer() over the whole tree; an empty tree yields (0, 0).
	std::pair<long long,long long> ask(long long k){
		if(n <= 0) return std::make_pair(0LL, 0LL);
		return quer(1,0,n-1,k);
	}

	// Returns (sum of all active values, number of active slots).
	std::pair<long long,long long> tp(){
		return std::make_pair(sum[1], cnt[1]);
	}
};
// Global per-vertex state shared by dfs() and minimum_closure_costs().
//   adj[v] : set of (neighbour, weight); filled from the input edges, entries are
//            erased as the degree limit grows, and dfs() walks what remains.
//   G[v]   : list of (neighbour, weight); filled from the input edges and not read
//            anywhere in the visible code.
//   T[v]   : segment tree with deg(v) slots holding weights activated via upd().
//   n, k   : number of vertices, and the degree limit currently being solved.
//   dp[v]  : the two values dfs() computes for vertex v.
//   dif[v] : only referenced from commented-out code.
//   vis[v] : visited flag, cleared for the remaining vertices before each round.
//   imp    : vertices still handled by dfs(); a vertex is removed once k reaches
//            its degree.
set<pair<int,int>>adj[mxn]; vector<pair<int,int>>G[mxn]; segtree T[mxn];
int n, k; int dp[mxn][2]; multiset<int, greater<int>>dif[mxn]; bool vis[mxn]; set<int>imp; 
void dfs(int node, int from){
	vis[node]=1; int sum = 0; vi tmp; 
	for(auto [u,w] : adj[node])if(u!=from)dfs(u,node), sum+=max(dp[u][0],dp[u][1]), tmp.pb(dp[u][0]+w-max(dp[u][0],dp[u][1])); 
	// if(k==1){vout(tmp); dout(dp[2][0])}
	sort(tmp.rbegin(),tmp.rend()); 
	int sz = T[node].tp().second; int lo = max(0LL,k - sz), hi = min(k,(int)tmp.size()); while(lo < hi){
		int mid = lo + (hi - lo + 1)/2; //can the first mid in tmp fit in the first k?
		if(tmp[mid-1] >= T[node].ask(k-mid+1).second)lo=mid; else hi=mid-1;
	}
	dp[node][1]=sum + T[node].ask(k-lo).first; f0r(i,lo)dp[node][1]+=tmp[i];
	pair<int,int> tt = T[node].tp(); int cz = tt.second, cs = tt.first; 
	f0r(i,tmp.size()){
		if(tmp[i]>=0)cs+=tmp[i], cz++; 
	}
	dp[node][0]=sum; if(cz < k-1)dp[node][0]+=cs; else{
		int sz = T[node].tp().second; int lo = max(0LL,k - sz), hi = min(k-1,(int)tmp.size()); while(lo < hi){
			int mid = lo + (hi - lo + 1)/2; //can the first mid in tmp fit in the first k-1?
			if(tmp[mid-1] >= T[node].ask(k-mid).second)lo=mid; else hi=mid-1;
		}
		dp[node][0] += T[node].ask(k-1-lo).first; f0r(i,lo)dp[node][0]+=tmp[i];
	}
	// for(auto [u,w] : adj[node])if(u!=from)dif[node].erase(dif[node].find(dp[u][0]+w-max(dp[u][0],dp[u][1]))); 
}
std::vector<long long> minimum_closure_costs(signed N, std::vector<signed> U,
                                             std::vector<signed> V,
                                             std::vector<signed> W) {
	n=N; vi ans; int S = 0; vi deg(n); vector<Edge>edges;
	f0r(i,n-1)edges.pb({U[i],V[i],W[i]}),adj[U[i]].insert(mp(V[i],W[i])),adj[V[i]].insert(mp(U[i],W[i])),G[U[i]].eb(V[i],W[i]),G[V[i]].eb(U[i],W[i]),S+=W[i],deg[U[i]]++,deg[V[i]]++; 
	vvi w(n); f0r(i,n)w[deg[i]].pb(i); f0r(i,n)imp.insert(i); int per = 0; vi plus(n); vector<vector<Edge>>minus(n); for(auto [u,v,w] : edges){
		plus[max(deg[u],deg[v])]+=w; minus[min(deg[u],deg[v])].pb({u,v,w});
	} 
	map<pair<int,int>,int>pos;
	f0r(i,n){
		T[i]=segtree(deg[i]); vector<pair<int,int>>tmp; for(auto [v,w] : adj[i])tmp.eb(w,v); sort(tmp.rbegin(),tmp.rend()); f0r(j,tmp.size()){
			pos[mp(i,tmp[j].second)]=j;
		}
	}
	f0r(K,n){
		k=K; per+=plus[k]; vi tmp; for(auto u : w[k]){
			imp.erase(u); //tmp.pb(u);  
		}
		for(auto [u,v,w] : minus[k]){
			if(deg[u] < deg[v])T[v].upd(pos[mp(v,u)],w); if(deg[v] < deg[u])T[u].upd(pos[mp(u,v)],w);
			// if(deg[u] < deg[v])dif[v].insert(w); if(deg[v] < deg[u])dif[u].insert(w);
			adj[u].erase(mp(v,w)), adj[v].erase(mp(u,w));
		}
		for(auto i : imp)f0r(j,2)dp[i][j]=0, vis[i]=0; int cur = S-per; //dout(per);
		for(auto i : imp)if(!vis[i])dfs(i,-1), cur -= max(dp[i][0],dp[i][1]); ans.pb(cur);
		
	} return ans;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...