제출 #840979

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
840979 |  | Cookie | 봉쇄 시간 (IOI23_closing) | C++17 | 43 / 100 | 989 ms | 63440 KiB
#include<bits/stdc++.h>
#include<fstream>
#include<closing.h>
using namespace std;
//ifstream fin("FEEDING.INP");
//ofstream fout("FEEDING.OUT");
// Short-hand macros.
// NOTE(review): `a` is not parenthesized in sz(a); this is fine for the plain
// identifiers passed in this file, but an expression argument could misparse.
#define sz(a) (int)a.size()
#define ll long long
#define pb push_back
// forr/dorr, ld, fi/se and pii are not used in the code visible here.
#define forr(i, a, b) for(int i = a; i < b; i++)
#define dorr(i, a, b) for(int i = a; i >= b; i--)
#define ld long double
#define vt vector
// Duplicate of the <fstream> include above; harmless.
#include<fstream>
#define fi first
#define se second
#define pll pair<ll, ll>
#define pii pair<int, int>
// mxn bounds the 1-based per-city arrays; inf is not used in the code visible here.
const int mxn = 2e5 + 5, inf = 1e9;
// Number of cities and the two special cities, already converted to 1-based by max_score().
int n, x, y;
// Budget K passed to max_score(); total costs are compared against it.
ll k;
// costx[i] / costy[i]: path distance from x / y to city i (filled by solve()).
// u, v, w: 1-based endpoints and weight of edge i (filled by max_score()).
ll costx[mxn + 1], costy[mxn + 1], u[mxn + 1], v[mxn + 1], w[mxn + 1];
// Prefix sums over i = 1..n of costx[i], costy[i] and max(costx[i], costy[i]) (index 0 stays 0).
ll prefx[mxn + 1], prefy[mxn + 1], prefmx[mxn + 1];
// Edge weight keyed by its endpoint pair; max_score() stores both orientations.
map<pair<int, int>, ll>mp;
// Sorted costs solve() may buy as optional extensions (costx left of x, costy
// right of y); also the coordinate-compression domain used by find().
vt<ll>comp;
int find(int x){
    return(lower_bound(comp.begin(), comp.end(), x) - comp.begin() + 1);
}
struct ST{
    ll st[4 * mxn + 1];
    void init(){
        for(int i = 1; i <= 8 * n; i++)st[i] = 0;
    }
    void upd(int nd, int l, int r, int id, ll v){
        if(id < l || id > r)return;
        if(l == r){
            assert(l == id);
            st[nd] += v;
            return;
        }
        int mid = (l + r) >> 1;
        upd(nd << 1, l, mid, id, v); upd(nd << 1 | 1, mid + 1, r, id, v);
        st[nd] = st[nd << 1] + st[nd << 1 | 1];
    }
    ll get(int nd, int l, int r, int ql, int qr){
        if(ql > r || qr < l)return(0);
        if(ql <= l && qr >= r){
            //cout << nd << "    " << st[nd] << " ";
            return(st[nd]);
        }
        int mid = (l + r) >> 1;
        return(get(nd << 1, l, mid, ql, qr) + get(nd << 1 | 1, mid + 1, r, ql, qr));
    }
    int kth(int nd, int l, int r, ll k){
        if(l == r)return(l - 1);
        int mid = (l + r) >> 1;
        if(st[nd << 1] >= k){
            return(kth(nd << 1, l, mid, k));
        }else{
            return(kth(nd << 1 | 1, mid + 1, r, k - st[nd << 1]));
        }
    }
};
ST sm, cnt;
// Returns how many of the optional costs currently stored in the trees can be
// bought with budget m, taking the cheapest first. Per position of `comp`,
// `sm` holds the sum of the stored costs and `cnt` their multiplicity.
int get(ll m){
    // Last leaf. No cost maps here, so it acts as a sentinel.
    const int hi = sz(comp) + 1;
    // Positions 1..whole are bought entirely. Unless everything fits,
    // position whole + 1 is where the running total first reaches m.
    const int whole = sm.kth(1, 1, hi, m);
    const ll spent = sm.get(1, 1, hi, 1, whole);
    int bought = cnt.get(1, 1, hi, 1, whole);
    if(whole != sz(comp)){
        // Spend what is left on copies of the next value comp[whole].
        bought += (m - spent) / comp[whole];
    }
    return bought;
}
// Solver for the case where the tree is the path 1 - 2 - ... - n. Edge i joins
// cities i and i + 1, and its weight is read from mp[{i, i + 1}].
// Tries two families of assignments and returns the best count found:
//  1. Baseline: each city may be counted for x (paying costx) and/or for y
//     (paying costy) independently; the cheapest costs are taken greedily.
//  2. Overlap: for every l in [1, y] and r in [max(l, x), n], x reaches
//     [min(l, x), r] and y reaches [l, max(r, y)]. Cities in [l, r] are
//     reached by both and cost max(costx, costy). get() spends any leftover
//     budget, cheapest first, on extending x further left or y further right,
//     using the segment trees sm/cnt.
// NOTE(review): the (l, r) double loop does O(log n) tree work per pair plus an
// O(n) reset per l, i.e. O(n^2 log n) overall, which is likely too slow for large n.
int solve(){
    //cout << n << " " << x << " " << y << " ";
    comp.clear();
    ll pref = 0;
    costx[x] = costy[y] = 0;
    // costx[i] for i < x, walking left from x. These are also the candidate
    // left-extensions for x, so they go into comp.
    for(int i = x - 1; i >= 1; i--){
        pref += mp[{i, i + 1}];
        costx[i] = pref;
        comp.pb(costx[i]);
    }
    pref = 0;
    // costx[i] for i > x, walking right from x.
    for(int i = x; i < n; i++){
        pref += mp[{i, i + 1}];
        costx[i + 1] = pref;
    }
    pref = 0;
    // costy[i] for i < y, walking left from y.
    for(int i = y - 1; i >= 1; i--){
        pref += mp[{i, i + 1}];
        costy[i] = pref;
    }
    pref = 0;
    // costy[i] for i > y. These are the candidate right-extensions for y, so
    // they also go into comp.
    for(int i = y; i < n; i++){
        pref += mp[{i, i + 1}];
        costy[i + 1] = pref;
        comp.pb(costy[i + 1]);

    }
    // Sorted so that find() can compress a cost with lower_bound.
    sort(comp.begin(), comp.end());
    // 1-based prefix sums, used to price whole ranges in O(1).
    for(int i = 1; i <= n; i++){
        prefmx[i] = prefmx[i - 1] + max(costx[i], costy[i]);
        prefx[i] = prefx[i - 1] + costx[i];
        prefy[i] = prefy[i - 1] + costy[i];
    }
    // Family 1 (baseline): all 2n per-city costs, cheapest first.
    vt<ll>comp2;
    for(int i = 1; i <= n; i++)comp2.pb(costx[i]);
    for(int i = 1; i <= n; i++)comp2.pb(costy[i]);
    sort(comp2.begin(), comp2.end());
    int ans = 0;
    ll curr = 0;
    for(int i = 0; i < sz(comp2); i++){
        if(curr + comp2[i] <= k){
            curr += comp2[i]; ans++;
        }
    }
    // Family 2: iterate over the intersection [l, r] of the two reached ranges.
    for(int l = 1; l <= y; l++){
        sm.init(); cnt.init();
        // Left-extensions available to x: every city left of min(l, x).
        for(int i = min(l, x) - 1; i >= 1; i--){
            cnt.upd(1, 1, sz(comp) + 1, find(costx[i]), 1);
            sm.upd(1, 1, sz(comp) + 1, find(costx[i]), costx[i]);
        }
        // r decreases, so the trees accumulate y's right-extensions: the cities
        // beyond max(r, y), inserted at the bottom of this loop.
        for(int r = n; r >= max(l, x); r--){
            ll cost = 0;
            // Cities in [l, r] are reached by both cities and pay the larger distance.
            cost += prefmx[r] - prefmx[l - 1];
            //for(int i = x; i < l; i++)cost += costx[i];
            //for(int i = y; i > r; i--)cost += costy[i];
            // Cities in [x, l - 1] are reached by x only.
            if(x < l){
                cost += prefx[l - 1] - prefx[x - 1];
            }
            // Cities in [r + 1, y] are reached by y only.
            if(y > r){
                cost += prefy[y] - prefy[r];
            }

            /*
            vt<ll>comp;
            for(int i = min(x, l) - 1; i >= 1; i--)comp.pb(costx[i]);
            for(int i = max(y, r) + 1; i <= n; i++)comp.pb(costy[i]);
            sort(comp.begin(), comp.end());
            */
            if(cost <= k){
            // Size of x's range [min(l, x), r] plus size of y's range [l, max(r, y)].
            int cand = r - min(l, x) + 1 + max(r, y) - l + 1;
            // Spend the remaining budget on the cheapest available extensions.
            cand += get(k - cost);

            //k -= cost;
            //cand += get(min(x, l) - 1, max(y, r) + 1);


            //if(cand == 7)cout << min(x, l) - 1 << " " << max(y, r) + 1 << " ";
            ans = max(ans, cand);
            }
            // Once r moves left, city r (if beyond y) is reachable only as a
            // right-extension of y, so make it available to get().
            if(r > y){
                cnt.upd(1, 1, sz(comp) + 1, find(costy[r]), 1);
                sm.upd(1, 1, sz(comp) + 1, find(costy[r]), costy[r]);
            }
        }
    }
    return(ans);
}
int mxres = 0;
vt<pll>adj[mxn + 1];
vt<ll>compdep;
void dfs(int s, int pre, ll dep){
    compdep.pb(dep);
    for(auto [v, w]: adj[s]){
        if(v != pre){
            dfs(v, s, dep + w);
        }
    }
}

int max_score(int N, int X, int Y, long long K,
              std::vector<int> U, std::vector<int> V, std::vector<int> W)
{
    mp.clear();
    n = N; x = ++X; y = ++Y; k = K;
    for(int i = 1; i <= n; i++)adj[i].clear();
    bool ok = 1;
    for(int i = 0; i < n - 1; i++){

        u[i] = ++U[i]; v[i] = ++V[i]; w[i] = W[i];
        if(u[i] != i + 1 || v[i] != i + 2)ok = 0;
        mp[{u[i], v[i]}] = mp[{v[i], u[i]}] = w[i];
        adj[u[i]].pb(make_pair(v[i], w[i])); adj[v[i]].pb(make_pair(u[i], w[i]));
    }
    if(ok)return solve();
    else{
        compdep.clear();
        dfs(x, -1, 0);
        dfs(y, -1, 0);
        sort(compdep.begin(), compdep.end());
        int res = 0;
        ll curr = 0;
        for(int i = 0; i < sz(compdep); i++){
            if(curr + compdep[i] <= k){
                curr += compdep[i]; res++;
            }
        }
        return(res);
    }
}
/*
int main()
{

    int Q;
    assert(1 == scanf("%d", &Q));

    std::vector<int> N(Q), X(Q), Y(Q);
    std::vector<long long> K(Q);
    std::vector<std::vector<int>> U(Q), V(Q), W(Q);

    for (int q = 0; q < Q; q++)
    {
        assert(4 == scanf("%d %d %d %lld", &N[q], &X[q], &Y[q], &K[q]));

        U[q].resize(N[q] - 1);
        V[q].resize(N[q] - 1);
        W[q].resize(N[q] - 1);
        for (int i = 0; i < N[q] - 1; ++i)
        {
            assert(3 == scanf("%d %d %d", &U[q][i], &V[q][i], &W[q][i]));
        }
    }
    fclose(stdin);

    std::vector<int> result(Q);
    for (int q = 0; q < Q; q++)
    {
        result[q] = max_score(N[q], X[q], Y[q], K[q], U[q], V[q], W[q]);
    }

    for (int q = 0; q < Q; q++)
    {
        printf("%d\n", result[q]);
    }
    fclose(stdout);

    return 0;
}
*/
/*
1
5 0 3 6
0 1 1
1 2 1
2 3 1
3 4 100
*/
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...