Submission #1094591

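| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1094591 | | azberjibiou | Closing Time (IOI23_closing) | C++17 | 100 / 100 | 758 ms | 73176 KiB |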
// IOI 2023 "Closing Time" (task IOI23_closing).
//
// max_score() returns the larger of two greedy counts, both computed from each
// vertex's tree distances to X and to Y (stored as near = smaller, far = larger):
//   * count_disjoint    - every vertex contributes at most 1, at cost `near`;
//   * count_overlapping - vertices on the X-Y path are prepaid `near`, then each
//                         vertex sits at level 0/1/2 (cost 0 / near / far) and
//                         contributes its level to the score.
// Vertex ids are 0-based throughout.

// The grader header only declares max_score(); include it when it is available
// so the definition below is checked against the official prototype.
#if __has_include("closing.h")
#include "closing.h"
#endif

#include <algorithm>
#include <set>
#include <utility>
#include <vector>

namespace {

using ll = long long;

// Cost reported for a move that is currently unavailable. Assumes the budget K
// is far below this (presumably K <= 1e18 per the task limits - confirm), so
// `spent + INF` neither overflows nor fits in the budget.
constexpr ll INF = 2'000'000'000'000'000'000LL;

// Adjacency list: adj[u] holds (neighbour, edge weight) pairs.
using Graph = std::vector<std::vector<std::pair<int, int>>>;

// Weighted distance from `src` to every vertex of the tree `adj`.
// Iterative (explicit stack) so a long path-shaped tree cannot exhaust the
// call stack the way a recursive DFS can.
std::vector<ll> distances_from(int src, const Graph& adj) {
    std::vector<ll> dist(adj.size(), 0);
    std::vector<std::pair<int, int>> stack;  // (vertex, parent it was reached from)
    stack.emplace_back(src, -1);
    while (!stack.empty()) {
        const auto [u, parent] = stack.back();
        stack.pop_back();
        for (const auto& [v, w] : adj[u]) {
            if (v == parent) continue;  // tree: the parent is the only visited neighbour
            dist[v] = dist[u] + w;
            stack.emplace_back(v, u);
        }
    }
    return dist;
}

// Strict weak order on vertex ids: by key[id], ties broken by id. A std::set
// using it acts as a priority queue that also supports erasing arbitrary ids.
// Keys must not change while an id is stored in such a set.
struct ByKey {
    const std::vector<ll>* key;
    bool operator()(int a, int b) const {
        const ll ka = (*key)[a];
        const ll kb = (*key)[b];
        return ka != kb ? ka < kb : a < b;
    }
};
using KeyedSet = std::set<int, ByKey>;

// Each vertex may be bought once at cost[i]; buy cheapest first while the
// running total stays within `budget`. Returns how many were bought.
ll count_disjoint(std::vector<ll> cost, ll budget) {
    std::sort(cost.begin(), cost.end());
    ll spent = 0;
    ll taken = 0;
    for (const ll c : cost) {
        if (spent + c > budget) break;  // ascending order: no later cost fits either
        spent += c;
        ++taken;
    }
    return taken;
}

// Greedy over per-vertex levels 0/1/2, where level 1 costs near[i], level 2
// costs far[i] in total, and a vertex adds its level to the score.
// Vertices with near + far == dist_xy are first prepaid their `near` cost
// (their near drops to 0 and far to far - near). Returns 0 when that prepaid
// amount already exceeds `budget`.
ll count_overlapping(std::vector<ll> near, std::vector<ll> far, ll dist_xy, ll budget) {
    const int n = static_cast<int>(near.size());

    ll spent = 0;
    for (int i = 0; i < n; ++i) {
        if (near[i] + far[i] == dist_xy) {
            spent += near[i];
            far[i] -= near[i];
            near[i] = 0;
        }
    }
    if (spent > budget) return 0;

    std::vector<ll> gap(n);  // extra cost of going from level 1 to level 2
    for (int i = 0; i < n; ++i) gap[i] = far[i] - near[i];

    KeyedSet lvl0_by_near(ByKey{&near});
    KeyedSet lvl0_by_far(ByKey{&far});
    KeyedSet lvl1_by_near(ByKey{&near});
    KeyedSet lvl1_by_gap(ByKey{&gap});
    KeyedSet lvl2_by_gap(ByKey{&gap});
    for (int i = 0; i < n; ++i) {
        lvl0_by_near.insert(i);
        lvl0_by_far.insert(i);
    }

    auto enter_level0 = [&](int v) { lvl0_by_near.insert(v); lvl0_by_far.insert(v); };
    auto leave_level0 = [&](int v) { lvl0_by_near.erase(v); lvl0_by_far.erase(v); };
    auto enter_level1 = [&](int v) { lvl1_by_near.insert(v); lvl1_by_gap.insert(v); };
    auto leave_level1 = [&](int v) { lvl1_by_near.erase(v); lvl1_by_gap.erase(v); };

    // open: cheapest level-0 vertex -> level 1.
    auto cost_open = [&]() -> ll {
        return lvl0_by_near.empty() ? INF : near[*lvl0_by_near.begin()];
    };
    // upgrade: level-1 vertex with the smallest gap -> level 2.
    auto cost_upgrade = [&]() -> ll {
        return lvl1_by_gap.empty() ? INF : gap[*lvl1_by_gap.begin()];
    };
    // trade1: level-1 vertex with the largest near -> level 0 (refund near),
    //         level-0 vertex with the smallest far -> level 2.
    auto cost_trade1 = [&]() -> ll {
        if (lvl1_by_near.empty() || lvl0_by_far.empty()) return INF;
        return far[*lvl0_by_far.begin()] - near[*lvl1_by_near.rbegin()];
    };
    // trade2: level-2 vertex with the largest gap -> level 1 (refund gap),
    //         level-0 vertex with the smallest far -> level 2.
    auto cost_trade2 = [&]() -> ll {
        if (lvl2_by_gap.empty() || lvl0_by_far.empty()) return INF;
        return far[*lvl0_by_far.begin()] - gap[*lvl2_by_gap.rbegin()];
    };

    ll score = 0;
    // Every move raises the total level by exactly 1; stop once all vertices
    // are at level 2 or the cheapest move no longer fits in the budget.
    while (!lvl0_by_near.empty() || !lvl1_by_near.empty()) {
        const ll c_open = cost_open();
        const ll c_upgrade = cost_upgrade();
        const ll c_trade1 = cost_trade1();
        const ll c_trade2 = cost_trade2();
        const ll best = std::min({c_open, c_upgrade, c_trade1, c_trade2});
        if (spent + best > budget) break;
        spent += best;
        ++score;

        // Ties resolve in the order open, upgrade, trade1, trade2.
        if (best == c_open) {
            const int v = *lvl0_by_near.begin();
            leave_level0(v);
            enter_level1(v);
        } else if (best == c_upgrade) {
            const int v = *lvl1_by_gap.begin();
            leave_level1(v);
            lvl2_by_gap.insert(v);
        } else if (best == c_trade1) {
            // Pick both vertices before touching any set, so `up` cannot be `down`.
            const int down = *lvl1_by_near.rbegin();
            const int up = *lvl0_by_far.begin();
            leave_level1(down);
            enter_level0(down);
            leave_level0(up);
            lvl2_by_gap.insert(up);
        } else {
            const int down = *lvl2_by_gap.rbegin();
            const int up = *lvl0_by_far.begin();
            lvl2_by_gap.erase(down);
            enter_level1(down);
            leave_level0(up);
            lvl2_by_gap.insert(up);
        }
    }
    return score;
}

}  // namespace

// Entry point called by the grader. Edge i joins U[i] and V[i] with weight
// W[i] (0-based vertex ids); x and y are the two special vertices, k the budget.
int max_score(int n, int x, int y, long long k,
              std::vector<int> U, std::vector<int> V, std::vector<int> W) {
    Graph adj(n);
    for (int i = 0; i + 1 < n; ++i) {
        adj[U[i]].emplace_back(V[i], W[i]);
        adj[V[i]].emplace_back(U[i], W[i]);
    }

    std::vector<ll> near = distances_from(x, adj);
    std::vector<ll> far = distances_from(y, adj);
    const ll dist_xy = near[y];  // read before the swap below reorders near/far
    for (int i = 0; i < n; ++i) {
        if (near[i] > far[i]) std::swap(near[i], far[i]);
    }

    const ll disjoint = count_disjoint(near, k);
    const ll overlapping = count_overlapping(std::move(near), std::move(far), dist_xy, k);
    return static_cast<int>(std::max(disjoint, overlapping));
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...