Submission #1094591

#TimeUsernameProblemLanguageResultExecution timeMemory
1094591azberjibiouClosing Time (IOI23_closing)C++17
100 / 100
758 ms73176 KiB
#include "closing.h"

#include <bits/stdc++.h>
#define all(v) v.begin(), v.end()
#define pb push_back
#define lb lower_bound
#define gibon ios::sync_with_stdio(false); cin.tie(0);
#define fi first
#define se second
#define pii pair<int, int>
#define pll pair<ll, ll>
typedef long long ll;
using namespace std;
// Capacity of the per-vertex arrays; vertex ids are 1-based (presumably N up to ~2e5 — confirm).
const int mxN=200010;
const int mxM=200100;   // not referenced in the code shown here
const int mxK=61;       // not referenced in the code shown here
const int MOD=1e9;      // not referenced in the code shown here
// Sentinel cost returned by the f* move helpers when a move is not available.
const ll INF=2e18;
// Vertex count and the two source vertices whose distances fill d[] and D[] (1-based after reading).
int N, X, Y;
// Budget: upper bound on the summed cost checked by solv1()/solv2().
ll K;
// Adjacency list: v[a] holds (neighbor, edge weight) pairs.
vector <pll> v[mxN];
// Distances to X and Y; after solve() swaps per vertex, d is the smaller and D the larger.
ll d[mxN], D[mxN];
// Strict ordering of vertex ids by d[] ascending, ties broken by id, so two
// distinct ids never compare equal inside a std::set. The d[] key of an id
// must not change while that id is stored in a set using this comparator.
struct cmpd{
    bool operator()(const int a, const int b) const {
        return std::make_pair(d[a], a) < std::make_pair(d[b], b);
    }
};
// Strict ordering of vertex ids by D[] ascending, ties broken by id.
// The D[] key of an id must not change while it sits in a set using this comparator.
struct cmpD{
    bool operator()(const int a, const int b) const {
        return std::make_pair(D[a], a) < std::make_pair(D[b], b);
    }
};
// Strict ordering of vertex ids by the gap D[]-d[] ascending, ties broken by id.
// Neither d[] nor D[] of an id may change while it sits in a set using this comparator.
struct cmpDd{
    bool operator()(const int a, const int b) const {
        return std::make_pair(D[a]-d[a], a) < std::make_pair(D[b]-d[b], b);
    }
};
// Per-vertex "level" buckets used by solv2() and the f* move helpers:
//   level 0: s0d (ordered by d) and s0D (ordered by D) hold the same ids;
//   level 1: s1d (ordered by d) and s1Dd (ordered by D-d) hold the same ids;
//   level 2: s2Dd (ordered by D-d).
set <int, cmpd> s0d, s1d;
set <int, cmpD> s0D;
set <int, cmpDd> s1Dd, s2Dd;

// Reads N, X, Y, K and N-1 weighted edges (0-based endpoints) from stdin and
// stores them in the 1-based adjacency list v. Only used by the commented-out main().
void input(){
    cin >> N >> X >> Y >> K;
    ++X; ++Y;   // switch to 1-based vertex ids
    for(int e=0;e<N-1;e++){
        int from, to, len;
        cin >> from >> to >> len;
        ++from; ++to;
        v[from].emplace_back(to, len);
        v[to].emplace_back(from, len);
    }
}
// Binary-lifting table: par[u][j] is the 2^j-th ancestor of u; 0 above the root.
int par[mxN][20];
ll dep[mxN], depp[mxN];//dep: weighted distance from the root, depp: number of edges from the root
// Resets every piece of global state touched by solve() for vertices 1..N,
// so that another max_score() call starts from a clean slate.
void init(){
    s0d.clear(); s1d.clear(); s0D.clear(); s1Dd.clear(); s2Dd.clear();
    for(int i=1;i<=N;i++){
        v[i].clear();
        d[i]=D[i]=0;
        dep[i]=depp[i]=0;
        for(int& anc : par[i]) anc=0;
    }
}
// Roots the tree at `now` (`pre` is its parent id, -1 for the real root) and
// fills, for every other vertex u reachable from it:
//   dep[u]    = dep[now] + weighted path length from now to u,
//   depp[u]   = depp[now] + number of edges from now to u,
//   par[u][0] = parent of u.
// par[now][0] is not written (it stays 0 from zero-init / init()).
// Uses an explicit stack instead of recursion. The recursive form nests one
// frame per tree level, i.e. up to N deep on a path-shaped tree (mxN allows
// ~2e5 vertices), which can overflow a default-sized call stack.
void dfs(int now, int pre){
    std::vector<std::pair<int, int>> stk;   // (vertex, its parent)
    stk.emplace_back(now, pre);
    while(!stk.empty()){
        const auto [cur, from] = stk.back();
        stk.pop_back();
        for(const auto& [nxt, x] : v[cur]) if(nxt!=from){
            // Parent values are final before a child is pushed, so children
            // can be expanded in any order.
            dep[nxt]=dep[cur]+x;
            depp[nxt]=depp[cur]+1;
            par[nxt][0]=cur;
            stk.emplace_back(static_cast<int>(nxt), cur);
        }
    }
}
// Builds the binary-lifting table from par[.][0]: the 2^lv-th ancestor is the
// 2^(lv-1)-th ancestor of the 2^(lv-1)-th ancestor. Level lv-1 is complete for
// every vertex before level lv is computed.
void make_sps(){
    for(int lv=1;lv<20;lv++){
        for(int node=1;node<=N;node++){
            const int half=par[node][lv-1];
            par[node][lv]=par[half][lv-1];
        }
    }
}
int lca(int a, int b){
    if(depp[a]<depp[b]) swap(a, b);
    for(int i=19;i>=0;i--) if(depp[a]>=depp[b]+(1<<i)) a=par[a][i];
    if(a==b) return a;
    for(int i=19;i>=0;i--) if(par[a][i]!=par[b][i]) a=par[a][i], b=par[b][i];
    return par[a][0];
}
// Weighted length of the tree path between a and b: a -> lca -> b.
ll dist(int a, int b){
    const int meet=lca(a, b);
    return (dep[a]-dep[meet])+(dep[b]-dep[meet]);
}
// Each vertex may be taken at most once at cost d[i] (the smaller distance
// after solve()'s swap). Takes vertices cheapest-first and returns how many
// fit while the running total stays <= K.
ll solv1(){
    vector <ll> costs(d+1, d+N+1);
    sort(all(costs));
    ll spent=0, taken=0;
    for(ll c : costs){
        // Sorted ascending: once one cost does not fit, no later one can.
        if(spent+c>K) break;
        spent+=c;
        ++taken;
    }
    return taken;
}
// Move 0 -> 1: raise the level-0 vertex with the smallest d to level 1.
// del=false: return its cost d (INF if level 0 is empty), change nothing.
// del=true : apply the move and return 0 (INF and no change if level 0 is empty).
ll f01(bool del){
    if(s0d.empty()) return INF;
    const int pick=*s0d.begin();
    if(del){
        s0d.erase(pick); s0D.erase(pick);
        s1d.insert(pick); s1Dd.insert(pick);
        return 0;
    }
    return d[pick];
}
// Move 1 -> 2: raise the level-1 vertex with the smallest gap D-d to level 2.
// del=false: return that gap (INF if level 1 is empty), change nothing.
// del=true : apply the move and return 0 (INF and no change if level 1 is empty).
ll f12(bool del){
    if(s1Dd.empty()) return INF;
    const int pick=*s1Dd.begin();
    if(del){
        s1d.erase(pick); s1Dd.erase(pick);
        s2Dd.insert(pick);
        return 0;
    }
    return D[pick]-d[pick];
}
// Exchange move 1->0 plus 0->2: demote the level-1 vertex with the largest d
// back to level 0 (refunding d[n1]) and promote the level-0 vertex with the
// smallest D straight to level 2 (paying D[n2]). Net total level: +1.
// Returns INF (and changes nothing) if level 1 or level 0 is empty.
// del=false: return the net cost D[n2]-d[n1] without changing anything.
// del=true : apply the move and return 0.
ll f1002(bool del){
    if(s1d.empty() || s0D.empty()) return INF;
    int n1=*s1d.rbegin(), n2=*s0D.begin();
    if(!del) return D[n2]-d[n1];
    // n1: level 1 -> level 0.
    s1d.erase(n1), s1Dd.erase(n1);
    s0d.insert(n1), s0D.insert(n1);
    // n2: level 0 -> level 2.
    s0d.erase(n2), s0D.erase(n2);
    s2Dd.insert(n2);
    return 0;
}
// Exchange move 2->1 plus 0->2: demote the level-2 vertex with the largest gap
// D-d to level 1 (refunding D[n1]-d[n1]) and promote the level-0 vertex with
// the smallest D straight to level 2 (paying D[n2]). Net total level: +1.
// Returns INF (and changes nothing) if level 2 or level 0 is empty.
// del=false: return the net cost D[n2]-(D[n1]-d[n1]) without changing anything.
// del=true : apply the move and return 0.
ll f2102(bool del){
    if(s2Dd.empty() || s0D.empty()) return INF;
    int n1=*s2Dd.rbegin(), n2=*s0D.begin();
    if(!del) return D[n2]-(D[n1]-d[n1]);
    // n1: level 2 -> level 1.
    s2Dd.erase(n1);
    s1d.insert(n1), s1Dd.insert(n1);
    // n2: level 0 -> level 2.
    s0d.erase(n2), s0D.erase(n2);
    s2Dd.insert(n2);
    return 0;
}
// Greedy in which a vertex may contribute up to 2 to the count (presumably once
// per source vertex X / Y that reaches it — confirm against the problem).
// Each vertex has a level tracked by the global sets:
//   level 0 -> s0d, s0D;  level 1 -> s1d, s1Dd;  level 2 -> s2Dd.
// Raising 0->1 costs d[i]; raising 1->2 costs D[i]-d[i].
// Each iteration applies the cheapest of four moves (f01, f12, f1002, f2102),
// and each of those raises the total level by exactly 1. The loop stops when
// the next move would push the spent total above K.
// Returns the number of moves applied (= total level over all vertices), or 0
// when the prepaid X-Y path cost alone exceeds K.
// Side effects: rewrites d[]/D[] for on-path vertices and fills the level sets
// (max_score() calls init() afterwards to reset them).
ll solv2(){
    ll sum=0, cnt=0;
    ll dXY=dist(X, Y);
    // Vertices with d+D == dist(X,Y) lie on the X-Y path. Their cost d is paid
    // up front, so their 0->1 step becomes free (d=0) and the 1->2 step costs
    // only the remaining D-d.
    for(int i=1;i<=N;i++){
        if(d[i]+D[i]==dXY){
            sum+=d[i];
            D[i]-=d[i];
            d[i]=0;
        }
    }
    //for(int i=1;i<=N;i++) printf("d[%d]=%lld, D[%d]=%lld\n", i, d[i], i, D[i]);
    if(sum>K) return 0;
    //printf("sum=%lld\n", sum);
    // All vertices start at level 0. This must happen after d/D are final,
    // because the set comparators read d[]/D[].
    for(int i=1;i<=N;i++) s0d.insert(i), s0D.insert(i);
    while(true){
        // Nothing at level 0 or 1 means every vertex is already at level 2.
        if(s0d.empty() && s1d.empty()) break;
        /*printf("s0d\n");
        for(int x : s0d) printf("%d ", x);
        printf("\ns1d\n");
        for(int x : s1Dd) printf("%d ", x);
        printf("\ns2d\n");
        for(int x : s2Dd) printf("%d ", x);
        printf("\n");
        */
        // Query the cost of every move without applying any of them.
        ll val1=f01(false), val2=f12(false), val3=f1002(false), val4=f2102(false);
        ll val=min(min(min(val1, val2), val3), val4);
        //printf("val=%lld\n", val);
        // Unavailable moves cost INF (2e18). NOTE(review): sum+val stays in range
        // only while K is well below ~7e18 — confirm against the K bound.
        if(sum+val>K) break;
        sum+=val, cnt++;
        // Apply the cheapest move; ties go to f01, then f12, f1002, f2102.
        if(val==val1) f01(true);
        else if(val==val2) f12(true);
        else if(val==val3) f1002(true);
        else f2102(true);
    }
    return cnt;
}
int solve(){
    dfs(1, -1);
    //for(int i=1;i<=N;i++) printf("par[%d]=%d\n", i, par[i][0]);
    make_sps();
    for(int i=1;i<=N;i++) d[i]=dist(X, i), D[i]=dist(Y, i);
    for(int i=1;i<=N;i++) if(d[i]>D[i]) swap(d[i], D[i]);
    ll res1=solv1(), res2=solv2();
    //printf("res1=%lld, res2=%lld\n", res1, res2);
    return max(res1, res2);
}

int max_score(int n, int x, int y, long long k,
        std::vector<int> U, std::vector<int> V, std::vector<int> W)
{
    N=n, X=x, Y=y, K=k;
    X++, Y++;
    for(int i=0;i<N-1;i++){
        v[U[i]+1].emplace_back(V[i]+1, W[i]);
        v[V[i]+1].emplace_back(U[i]+1, W[i]);
    }
    ll ans=solve();
    init();
    return ans;
}
/*
int main()
{
    gibon
    input();
    cout << solve();
}
*/
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...