This submission was migrated from the previous version of oj.uz, which used a different machine for grading, so it may produce a different result if resubmitted.
#include "closing.h"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Shorthand macros used throughout the solution.
#define all(v) v.begin(), v.end()
#define pb push_back
#define lb lower_bound
#define gibon ios::sync_with_stdio(false); cin.tie(0);
#define fi first
#define se second
#define pii pair<int, int>
#define pll pair<ll, ll>

// Array capacities (sized for 1-indexed vertex ids) and numeric limits.
const int mxN = 200010;
const int mxM = 200100;
const int mxK = 61;
const int MOD = 1e9;
const ll INF = 2e18;    // "no move available" cost sentinel

// Problem input: vertex count, the two (1-indexed) start cities and the budget.
int N, X, Y;
ll K;
// v[a] holds (neighbour, edge weight) pairs.
vector<pll> v[mxN];
// Per-vertex distances to the two start cities; solve() stores them so that d[i] <= D[i].
ll d[mxN], D[mxN];
// Strict weak orderings over vertex ids, keyed on the global distance arrays.
// Ties are broken by the vertex id, so two distinct vertices never compare
// equal and a std::set can hold several vertices that share the same key.
// Keys are read at comparison time, so d[]/D[] of a vertex must not change
// while that vertex is stored in one of these sets.

// Orders vertices by d, then by id.
struct cmpd{
    bool operator()(const int a, const int b) const {
        return std::make_pair(d[a], a) < std::make_pair(d[b], b);
    }
};

// Orders vertices by D, then by id.
struct cmpD{
    bool operator()(const int a, const int b) const {
        return std::make_pair(D[a], a) < std::make_pair(D[b], b);
    }
};

// Orders vertices by the gap D - d, then by id.
struct cmpDd{
    bool operator()(const int a, const int b) const {
        return std::make_pair(D[a] - d[a], a) < std::make_pair(D[b] - d[b], b);
    }
};

// Vertex pools used by solv2, named after the level they hold (0, 1 or 2)
// and the key they are ordered by.
set <int, cmpd> s0d, s1d;
set <int, cmpD> s0D;
set <int, cmpDd> s1Dd, s2Dd;
// Reads one test from standard input (used by the commented-out main):
// N, X, Y, K followed by N-1 edges "a b c" with 0-indexed endpoints and
// weight c. Every vertex id is shifted by one to the 1-indexed layout.
void input(){
    cin >> N >> X >> Y >> K;
    ++X;
    ++Y;
    for(int e=0;e<N-1;e++){
        int from, to, w;
        cin >> from >> to >> w;
        ++from;
        ++to;
        v[from].emplace_back(to, w);
        v[to].emplace_back(from, w);
    }
}
// par[u][k]: the 2^k-th ancestor of u in the tree rooted at vertex 1 (0 = no ancestor).
int par[mxN][20];
// dep: weighted distance from the root, depp: number of edges from the root.
ll dep[mxN], depp[mxN];
// Resets all per-call global state for vertices 1..N (adjacency lists,
// distance arrays, depths, ancestor table) and empties the five vertex
// pools, so that max_score can be invoked again from a clean slate.
void init(){
    for(int u=1;u<=N;u++){
        v[u].clear();
        d[u]=0;
        D[u]=0;
        dep[u]=0;
        depp[u]=0;
        for(int k=0;k<20;k++) par[u][k]=0;
    }
    s0d.clear();
    s1d.clear();
    s0D.clear();
    s1Dd.clear();
    s2Dd.clear();
}
// Roots the tree at `now` (whose parent is `pre`) and, for every other vertex
// u reachable from it, fills dep[u] (weighted distance from the root),
// depp[u] (number of edges from the root) and par[u][0] (its parent).
// The root's own entries are left untouched.
// Implemented with an explicit stack: the recursive version nests one call
// per tree level, so a path-shaped tree with ~2e5 vertices can overflow the
// call stack. Each child's values depend only on its parent's, which are
// always written before the child is pushed, so visit order does not matter.
void dfs(int now, int pre){
    std::vector<std::pair<int, int>> stk;   // (vertex, its parent)
    stk.emplace_back(now, pre);
    while(!stk.empty()){
        auto [cur, from] = stk.back();      // copy before pop_back invalidates it
        stk.pop_back();
        for(auto [nxt, x] : v[cur]) if(nxt!=from){
            dep[nxt]=dep[cur]+x;
            depp[nxt]=depp[cur]+1;
            par[nxt][0]=cur;
            stk.emplace_back(static_cast<int>(nxt), cur);
        }
    }
}
// Builds the binary-lifting table from the parents written by dfs:
// par[u][k] = par[par[u][k-1]][k-1]. Levels are filled in increasing k so
// every level-(k-1) entry is final before it is used. Vertex 0 acts as the
// "no ancestor" sentinel; its row is never written and stays 0.
void make_sps(){
    for(int lv=1;lv<20;lv++){
        for(int u=1;u<=N;u++){
            const int half=par[u][lv-1];
            par[u][lv]=par[half][lv-1];
        }
    }
}
// Lowest common ancestor of a and b using the par[][] lifting table and the
// edge-count depths in depp[].
int lca(int a, int b){
    if(depp[b]>depp[a]) swap(a, b);
    // Lift the deeper vertex until both sit at the same depth.
    for(int k=19;k>=0;k--){
        if(depp[a]-(1<<k)>=depp[b]) a=par[a][k];
    }
    if(a==b) return a;
    // Lift both while their ancestors differ; they stop just below the LCA.
    for(int k=19;k>=0;k--){
        if(par[a][k]==par[b][k]) continue;
        a=par[a][k];
        b=par[b][k];
    }
    return par[a][0];
}
// Weighted tree distance between a and b: the two root distances minus
// twice the root distance of their lowest common ancestor.
ll dist(int a, int b){
    const ll viaRoot=dep[a]+dep[b];
    return viaRoot-2*dep[lca(a, b)];
}
// Counts how many vertices can be bought, one point each, at cost d[i]
// (when called from solve(), the smaller of the vertex's two distances),
// taking the cheapest first while the running total stays within K.
ll solv1(){
    vector <ll> cost(d+1, d+N+1);
    sort(all(cost));
    ll spent=0, taken=0;
    for(const ll c : cost){
        if(spent+c>K) continue;
        spent+=c;
        ++taken;
    }
    return taken;
}
// Move "level 0 -> 1": the level-0 vertex (pool s0d/s0D) with the smallest d
// moves into the level-1 pools (s1d/s1Dd) at cost d.
// del == false: only report the cost (INF when no candidate exists).
// del == true : apply the move and return 0.
ll f01(bool del){
    if(s0d.empty()) return INF;
    const int u=*s0d.begin();
    if(del){
        s0d.erase(u);
        s0D.erase(u);
        s1d.insert(u);
        s1Dd.insert(u);
        return 0;
    }
    return d[u];
}
// Move "level 1 -> 2": the level-1 vertex with the smallest gap D-d moves
// into the level-2 pool (s2Dd) at cost D-d.
// del == false: only report the cost (INF when no candidate exists).
// del == true : apply the move and return 0.
ll f12(bool del){
    if(s1Dd.empty()) return INF;
    const int u=*s1Dd.begin();
    if(del){
        s1Dd.erase(u);
        s1d.erase(u);
        s2Dd.insert(u);
        return 0;
    }
    return D[u]-d[u];
}
// Exchange "1 -> 0, 0 -> 2": the level-1 vertex with the largest d drops back
// to level 0 (refunding d) while the level-0 vertex with the smallest D jumps
// straight to level 2 (paying D). Net effect: one extra level.
// del == false: only report the net cost (INF when either side is missing).
// del == true : apply the exchange and return 0.
ll f1002(bool del){
    if(s1d.empty() || s0D.empty()) return INF;
    const int down=*s1d.rbegin();
    const int up=*s0D.begin();
    if(!del) return D[up]-d[down];
    // Demote `down` from level 1 back to level 0.
    s1d.erase(down);
    s1Dd.erase(down);
    s0d.insert(down);
    s0D.insert(down);
    // Promote `up` from level 0 directly to level 2.
    s0d.erase(up);
    s0D.erase(up);
    s2Dd.insert(up);
    return 0;
}
// Exchange "2 -> 1, 0 -> 2": the level-2 vertex with the largest gap D-d
// steps back to level 1 (refunding D-d) while the level-0 vertex with the
// smallest D jumps straight to level 2 (paying D). Net effect: one extra level.
// del == false: only report the net cost (INF when either side is missing).
// del == true : apply the exchange and return 0.
ll f2102(bool del){
    if(s2Dd.empty() || s0D.empty()) return INF;
    const int demoted=*s2Dd.rbegin();
    const int promoted=*s0D.begin();
    if(!del) return D[promoted]-(D[demoted]-d[demoted]);
    s2Dd.erase(demoted);
    s1d.insert(demoted);
    s1Dd.insert(demoted);
    s0d.erase(promoted);
    s0D.erase(promoted);
    s2Dd.insert(promoted);
    return 0;
}
// Greedy for the case where the two reach sets may overlap.
// Step 1: every vertex with d[i]+D[i]==dist(X,Y) (the X-Y path, assuming
// positive edge weights) is pre-paid at d[i] and rewritten as d[i]=0,
// D[i]=D_old-d_old. Its level 0->1 move is therefore free and its 1->2 move
// costs the remaining gap. Returns 0 if this mandatory cost already exceeds K.
// Step 2: all vertices start in the level-0 pools. The loop repeatedly
// applies the cheapest of the four moves f01 / f12 / f1002 / f2102; each move
// raises the total level count by exactly one. On cost ties, the earlier move
// in that list wins. The loop stops when the next move would exceed K, or when
// every vertex is at level 2.
// Returns the number of moves applied (= total levels gained).
// NOTE: mutates d[] and D[], so it must run after anything that needs the
// original distances.
ll solv2(){
ll sum=0, cnt=0;
ll dXY=dist(X, Y);
// Pre-pay the X-Y path and re-base its vertices' costs.
for(int i=1;i<=N;i++){
if(d[i]+D[i]==dXY){
sum+=d[i];
D[i]-=d[i];
d[i]=0;
}
}
//for(int i=1;i<=N;i++) printf("d[%d]=%lld, D[%d]=%lld\n", i, d[i], i, D[i]);
if(sum>K) return 0;
//printf("sum=%lld\n", sum);
// Every vertex starts at level 0.
for(int i=1;i<=N;i++) s0d.insert(i), s0D.insert(i);
while(true){
// No level-0 or level-1 vertices left: nothing can be upgraded.
if(s0d.empty() && s1d.empty()) break;
/*printf("s0d\n");
for(int x : s0d) printf("%d ", x);
printf("\ns1d\n");
for(int x : s1Dd) printf("%d ", x);
printf("\ns2d\n");
for(int x : s2Dd) printf("%d ", x);
printf("\n");
*/
// Price all four candidate moves without applying them (INF = unavailable).
ll val1=f01(false), val2=f12(false), val3=f1002(false), val4=f2102(false);
ll val=min(min(min(val1, val2), val3), val4);
//printf("val=%lld\n", val);
if(sum+val>K) break;
sum+=val, cnt++;
// Apply the move that achieved the minimum (first match on ties).
if(val==val1) f01(true);
else if(val==val2) f12(true);
else if(val==val3) f1002(true);
else f2102(true);
}
return cnt;
}
int solve(){
dfs(1, -1);
//for(int i=1;i<=N;i++) printf("par[%d]=%d\n", i, par[i][0]);
make_sps();
for(int i=1;i<=N;i++) d[i]=dist(X, i), D[i]=dist(Y, i);
for(int i=1;i<=N;i++) if(d[i]>D[i]) swap(d[i], D[i]);
ll res1=solv1(), res2=solv2();
//printf("res1=%lld, res2=%lld\n", res1, res2);
return max(res1, res2);
}
int max_score(int n, int x, int y, long long k,
std::vector<int> U, std::vector<int> V, std::vector<int> W)
{
N=n, X=x, Y=y, K=k;
X++, Y++;
for(int i=0;i<N-1;i++){
v[U[i]+1].emplace_back(V[i]+1, W[i]);
v[V[i]+1].emplace_back(U[i]+1, W[i]);
}
ll ans=solve();
init();
return ans;
}
/*
int main()
{
gibon
input();
cout << solve();
}
*/
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |
| # | Verdict | Execution time | Memory | Grader output |
|---|---|---|---|---|
| Fetching results... | | | | |