제출 #846518

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
846518 | onepunchac168 | 봉쇄 시간 (IOI23_closing) | C++17
100 / 100
394 ms | 53320 KiB
#include "closing.h" // do not change this

#include <bits/stdc++.h>
using namespace std;
// Short aliases for pair members and push_back, used throughout the file.
#define fi first
#define se second
#define pb push_back
typedef long long ll;
// Generic pair of long longs: {distance, node} in dijkstra, {neighbor, weight}
// in the adjacency list, {cost sum, item count} in the segment tree.
typedef pair <ll,ll > ii;
// Not referenced by the active code in this file.
typedef pair <ii,ll> iii;
// Newline constant; only referenced from commented-out debug output.
const char nl= '\n';
// Capacity of the per-node arrays; N is assumed to stay below this.
const int Na=2e5+5;
// Node count, the two source nodes X and Y, and the budget K, copied from max_score.
int n,x,y;
ll k;
// Adjacency list: vt[u] holds {neighbor, edge weight}.
vector <ii> vt[Na];
// Best score found so far; returned by max_score.
ll res=0;
// dp[i][0] / dp[i][1]: shortest distance from x / from y to node i (filled by dijkstra).
ll dp[Na][2];
// Visited flags used by dijkstra.
bool cp[Na];
void dijkstra(int u,int dd)
{
    priority_queue<ii,vector <ii>,greater <ii> > qu;
    for (int i=0; i<n; i++)
    {
        dp[i][dd]=1e18+5;
        cp[i]=0;
    }
    dp[u][dd]=0;
    qu.push({0,u});
    while (!qu.empty())
    {
        ii aa=qu.top();
        qu.pop();
        if (cp[aa.se]==1)
        {
            continue;
        }
        cp[aa.se]=1;
        for (auto v:vt[aa.se])
        {
            if (dp[v.fi][dd]>dp[aa.se][dd]+v.se)
            {
                dp[v.fi][dd]=dp[aa.se][dd]+v.se;
                qu.push({dp[v.fi][dd],v.fi});
            }
        }
    }
}
void solve1()
{
    dijkstra(x,0);
    dijkstra(y,1);
    vector <ll> cnt;
    for (int i=0; i<n; i++)
    {
        cnt.pb(min(dp[i][0],dp[i][1]));
    }
    ll ans=0;
    ll dem=0;
    sort (cnt.begin(),cnt.end());
    for (auto v:cnt)
    {
        if (dem+v<=k)
        {
            dem+=v;
            ans++;
        }
        else break;
    }
    res=ans;
}
bool check[Na];
void dfs(int u,int vv)
{
    if (u==y)
    {
        check[u]=1;
        return;
    }
    for (auto v:vt[u])
    {
        if (v.fi==vv)
        {
            continue;
        }
        dfs(v.fi,u);
        if (check[v.fi]==1)
        {
            check[u]=1;
        }
    }
}
/*const int M=3005;
ll dpa[M][M*2];
void solve2()
{
    for (int i=0;i<n;i++)
    {
        check[i]=0;
    }
    dfs(x,-1);
    for (int i=0;i<n;i++)
    {
        for (int j=0;j<=2*n;j++)
        {
            dpa[i][j]=1e18+5;
        }
    }
    if (check[0]==1)
    {
        dpa[0][1]=min(dp[0][1],dp[0][0]);
        dpa[0][2]=max(dp[0][1],dp[0][0]);
    }
    else
    {
        dpa[0][0]=0;
        dpa[0][1]=min(dp[0][1],dp[0][0]);
        dpa[0][2]=max(dp[0][1],dp[0][0]);
    }
    for (int i=1;i<n;i++)
    {
        for (int j=0;j<=2*(i+1);j++)
        {
            if (check[i]==1)
            {
                if (j>=1){
                    dpa[i][j]=min(dpa[i][j],min(dp[i][0],dp[i][1])+dpa[i-1][j-1]);
                }
                if (j>=2)
                {
                    dpa[i][j]=min(dpa[i][j],max(dp[i][0],dp[i][1])+dpa[i-1][j-2]);
                }
            }
            else
            {
                dpa[i][j]=dpa[i-1][j];
                if (j>=1){
                    dpa[i][j]=min(dpa[i][j],min(dp[i][0],dp[i][1])+dpa[i-1][j-1]);
                }
                if (j>=2)
                {
                    dpa[i][j]=min(dpa[i][j],max(dp[i][0],dp[i][1])+dpa[i-1][j-2]);
                }
            }
        }
    }
    for (int j=0;j<=2*n;j++)
    {
        //cout<<j<<" "<<dpa[n][j]<<" "<<k<<nl;
        if (dpa[n-1][j]<=k)
        {
            res=max(res,j);
        }
    }
}*/
// Segment-tree nodes, each {sum of stored costs, number of stored items}.
// solve3 builds it over at most 2n compressed values, hence 4 * (2 * Na) nodes.
ii T[8*Na];
// Zeroes every node of the segment tree rooted at `node` covering leaves [l, r].
void build (int node,int l,int r)
{
    T[node] = {0, 0};
    if (l == r)
    {
        return;
    }
    const int mid = (l + r) / 2;
    build(node * 2, l, mid);
    build(node * 2 + 1, mid + 1, r);
}
void update(int node,int l,int r,int u,ii val)
{
    if (l>u||r<u)
    {
        return;
    }
    if (l==r)
    {
        T[node].fi+=val.fi;
        T[node].se+=val.se;
        return;
    }
    update(node*2,l,(l+r)/2,u,val);
    update(node*2+1,(l+r)/2+1,r,u,val);
    T[node].fi=T[node*2].fi+T[node*2+1].fi;
    T[node].se=T[node*2].se+T[node*2+1].se;
}
// Accumulator written by query(): it receives the total cost of every item up
// to and including the leaf query() stops at. Callers reset it to 0 first.
ll sum;
// Walks down from `node` (covering leaves [l, r]) toward the first leaf at which
// the running prefix of costs reaches `val`. It moves right past a left child
// only when that child's cost total is strictly below the remaining budget. If
// the whole range sums below val, it ends on the rightmost leaf. Adds the prefix
// cost through the final leaf to `sum`. Returns {item count through that leaf,
// that leaf's index}.
ii query(int node,int l,int r,ll val)
{
    ll skippedCount = 0;
    while (l != r)
    {
        const int mid = (l + r) / 2;
        const ii& lhs = T[node * 2];
        if (lhs.fi < val)
        {
            sum += lhs.fi;
            skippedCount += lhs.se;
            val -= lhs.fi;
            node = node * 2 + 1;
            l = mid + 1;
        }
        else
        {
            node = node * 2;
            r = mid;
        }
    }
    sum += T[node].fi;
    return {T[node].se + skippedCount, l};
}
void solve3()
{
    vector <ll> tmp;
    for (int i=0; i<n; i++)
    {
        check[i]=0;
    }
    dfs(x,-1);
    ll resa=0;
    ll dem=0;
    vector <ii> gg;
    for (int i=0; i<n; i++)
    {
        if (check[i]==1)
        {
            resa+=min(dp[i][0],dp[i][1]);
            tmp.pb(abs(dp[i][0]-dp[i][1]));
            dem++;
        }
        else
        {
            tmp.pb(min(dp[i][0],dp[i][1]));
            tmp.pb(abs(dp[i][0]-dp[i][1]));
            gg.pb({max(dp[i][0],dp[i][1]),i});
        }
    }
    sort (gg.begin(),gg.end());
    sort (tmp.begin(),tmp.end());
    tmp.resize(unique(tmp.begin(),tmp.end())-tmp.begin());
    build(1,1,tmp.size());
    if (resa>k)
    {
        return;
    }
    for (int i=0; i<n; i++)
    {
        if (check[i]==1)
        {
            ll aa=lower_bound(tmp.begin(),tmp.end(),abs(dp[i][0]-dp[i][1]))-tmp.begin()+1;
            update(1,1,tmp.size(),aa, {abs(dp[i][0]-dp[i][1]),1});
        }
        else
        {
            ll aa=lower_bound(tmp.begin(),tmp.end(),min(dp[i][0],dp[i][1]))-tmp.begin()+1;
            update(1,1,tmp.size(),aa, {min(dp[i][0],dp[i][1]),1});
        }
    }
    ll ans=0;
    {
        ll h2=0;
        sum=0;
        ii h1=query(1,1,tmp.size(),k-resa);
        h2=h1.fi;
        ll gg=tmp[h1.se-1];
        if (sum>k-resa)
        {
            ll rr=(sum-k+resa)/gg;
            if (rr*gg<sum-k+resa)
            {
                rr++;
            }
            h2-=rr;
        }
        res=max(res,ans+dem+h2);
    }
    for (auto v:gg)
    {
        ans++;
        ll aa=lower_bound(tmp.begin(),tmp.end(),min(dp[v.se][0],dp[v.se][1]))-tmp.begin()+1;
        update(1,1,tmp.size(),aa, {-min(dp[v.se][0],dp[v.se][1]),-1});
        ll bb=lower_bound(tmp.begin(),tmp.end(),abs(dp[v.se][0]-dp[v.se][1]))-tmp.begin()+1;
        update(1,1,tmp.size(),bb, {abs(dp[v.se][0]-dp[v.se][1]),1});
        //cout<<v.se<<" "<<abs(dp[v.se][0]-dp[v.se][1])<<nl;
        resa+=min(dp[v.se][0],dp[v.se][1]);
        if (resa<=k)
        {
            ll h2=0;
            sum=0;
            ii h1=query(1,1,tmp.size(),k-resa);
            h2=h1.fi;
            ll gg=tmp[h1.se-1];
            if (sum>k-resa)
            {
                ll rr=(sum-k+resa)/gg;
                if (rr*gg<sum-k+resa)
                {
                    rr++;
                }
                h2-=rr;

            }
            res=max(res,ans+dem+h2);
            //cout<<sum-k+resa<<" "<<res<<" "<<ans+dem+h2<<" "<<ans<<" "<<dem<<" "<<h2<<" "<<k-resa<<nl;
        }
    }

}
// Running total of N across calls; only read by the disabled solve2 gate below.
ll dema=0;
// Entry point. Stores the inputs in the file-level globals and builds the
// undirected weighted adjacency list from edges (U[e], V[e], W[e]). It then
// runs solve1 and solve3, which both raise res, and returns the best score.
int max_score(int N, int X, int Y, long long K, std::vector<int> U, std::vector<int> V, std::vector<int> W)
{
    n = N;
    x = X;
    y = Y;
    k = K;
    res = 0;
    dema += N;

    for (int node = 0; node < n; ++node)
    {
        vt[node].clear();
    }
    for (int e = 0; e + 1 < n; ++e)
    {
        const int a = U[e];
        const int b = V[e];
        vt[a].pb({b, W[e]});
        vt[b].pb({a, W[e]});
    }

    solve1();
    /*if (n<=3000&&dema<=3000){
        solve2();
    }*/
    solve3();
    return res;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...