제출 #1328758

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
1328758 | vuton101982 | 봉쇄 시간 (IOI23_closing) | C++20
100 / 100
151 ms | 43144 KiB
#include "closing.h"
#include <bits/stdc++.h>
using namespace std;

// 64-bit signed integer used for every distance and cost in this file.
using ll = long long;
// Two 64-bit values (note: despite the name, this is NOT a pair of ints).
using pii = std::array<ll, 2>;

// Element count of any container exposing size(), narrowed to int so it can be
// compared with signed loop indices without sign-mismatch warnings.
template <class Container>
static inline int SZ(const Container &container) { return static_cast<int>(container.size()); }

// Savings tables for "convex" items: items whose first downgrade (2pt -> 1pt,
// saving `gap`) is worth at least as much as the second (1pt -> 0pt, saving
// `b - gap`). The caller (solveCoins) routes items with 2*gap >= b here.
//
// Every individual saving is pooled and sorted in descending order. Because an
// item's first saving is never smaller than its second, the k largest pooled
// savings can always be realised, so their prefix sum is the saving for losing
// k points.
//
// Input:  p[i] = {gap, b}.
// Returns {odd, even}:
//   odd[k]  = saving when losing 2k+1 points,
//   even[k] = saving when losing 2k points (even[0] == 0).
static pair<vector<ll>, vector<ll>> doConvex(vector<pii> p){
    vector<ll> savings;
    savings.reserve(2 * SZ(p));
    for (const pii &item : p) {
        savings.push_back(item[0]);            // 2pt -> 1pt
        savings.push_back(item[1] - item[0]);  // 1pt -> 0pt
    }
    sort(savings.rbegin(), savings.rend());    // descending

    vector<ll> oddLost, evenLost;
    ll prefix = 0;
    evenLost.push_back(prefix);                // losing 0 points saves nothing
    for (int k = 0; k < SZ(savings); ++k) {
        prefix += savings[k];                  // saving for losing k+1 points
        if ((k + 1) % 2 == 1) oddLost.push_back(prefix);
        else evenLost.push_back(prefix);
    }
    return {oddLost, evenLost};
}

// Savings tables for "concave" items: items whose first downgrade (2pt -> 1pt,
// saving `gap`) is worth less than the second (1pt -> 0pt). The caller
// (solveCoins) routes items with 2*gap < b here. The tables only consider
// configurations in which items are dropped whole (saving b, two points), with
// at most one item dropped halfway.
//
// Input:  p[i] = {gap, b}. Items are sorted by b descending.
// Returns {odd, even}:
//   even[k] = saving for losing 2k points = sum of the k largest b (even[0] == 0).
//   odd[k]  = saving for losing 2k+1 points = the larger of
//     (a) the k+1 largest b, with one of them only half-dropped
//         (its saving is gap instead of b), and
//     (b) the k largest b plus the largest gap among the remaining items.
static pair<vector<ll>, vector<ll>> doConcave(vector<pii> p){
    const int cnt = SZ(p);
    sort(p.begin(), p.end(), [](const pii &lhs, const pii &rhs){
        return lhs[1] > rhs[1];
    });

    // bestGapFrom[i] = largest gap among items i..cnt-1; sentinel past the end.
    vector<ll> bestGapFrom(cnt + 1, (ll)-2e18);
    for (int i = cnt - 1; i >= 0; --i)
        bestGapFrom[i] = max(bestGapFrom[i + 1], p[i][0]);

    vector<ll> odd, even{0};
    ll takenB = 0;                 // sum of b over items 0..i
    ll bestHalf = (ll)-2e18;       // max over items 0..i of (gap - b)
    for (int i = 0; i < cnt; ++i) {
        takenB += p[i][1];
        even.push_back(takenB);
        bestHalf = max(bestHalf, p[i][0] - p[i][1]);
        const ll halfDropOne = takenB + bestHalf;                        // option (a)
        const ll addBestGap  = takenB - p[i][1] + bestGapFrom[i];        // option (b)
        odd.push_back(max(halfDropOne, addBestGap));
    }
    return {odd, even};
}

// Minimum cost to collect t points from the off-path nodes, for every t in [0, 2n].
//
// Each item inp[i] = {a, b} (callers pass a = min reach cost, b = max) gives:
//   0 points: cost 0,   1 point: cost a,   2 points: cost b.
// The problem is solved in "savings" form. Start from every node at 2 points
// (cost sumDy = sum of b) and compute, for each number L of lost points, the
// largest achievable saving. Items are split by shape:
//   - convex: the first downgrade saves at least as much as the second;
//   - concave: otherwise.
// Each class yields parity-split saving sequences, which are merged by
// max-plus convolution.
//
// Returns a vector of size 2n+1 where result[t] = sumDy - (best saving when
// losing 2n - t points), i.e. the minimum cost to gain t points.
static vector<ll> solveCoins(vector<pii> inp){
    int n = SZ(inp);
    ll sumDy = 0;

    vector<pii> convex, concave;
    for(auto &ab: inp){
        ll a = ab[0], b = ab[1]; // callers pass {min, max}, so a <= b
        sumDy += b;
        ll gap = b - a;          // saving when downgrading 2pt -> 1pt
        // Convex iff gap >= a, i.e. the first downgrade saves at least the second.
        if(gap * 2 >= b) convex.push_back({gap, b});
        else concave.push_back({gap, b});
    }

    auto [odd1, even1] = doConcave(concave);
    auto [odd2, even2] = doConvex(convex);

    // Max-plus convolution: base a[0] + b[0], then every consecutive difference
    // of both inputs added in descending order. This assumes both inputs have
    // non-increasing differences (concave sequences).
    // If either side is empty, that parity combination is infeasible, so the
    // result is empty. Returning the other operand would record its savings
    // under the wrong lost-point count.
    auto conv = [](const vector<ll>& a, const vector<ll>& b)->vector<ll>{
        if(a.empty() || b.empty()) return {};
        vector<ll> dif;
        dif.reserve(SZ(a) + SZ(b));
        for(int i=1;i<SZ(a);i++) dif.push_back(a[i]-a[i-1]);
        for(int i=1;i<SZ(b);i++) dif.push_back(b[i]-b[i-1]);
        sort(dif.begin(), dif.end(), greater<ll>());
        vector<ll> v;
        v.reserve(SZ(dif)+1);
        v.push_back(a[0] + b[0]);
        for(ll x: dif) v.push_back(v.back()+x);
        return v;
    };

    // bestSave[L] = largest saving when losing exactly L points in total.
    vector<ll> bestSave(2*n + 1, 0);
    // Entry i of a combined sequence loses offset + 2*i points. `offset` is the
    // number of points lost by the two base entries (each odd base loses 1), and
    // every merged difference loses 2 more points.
    auto record = [&](const vector<ll>& comb, int offset){
        for(int i=0; i<SZ(comb) && offset + 2*i < SZ(bestSave); i++)
            bestSave[offset + 2*i] = max(bestSave[offset + 2*i], comb[i]);
    };
    record(conv(odd1, odd2), 2);
    record(conv(odd1, even2), 1);
    record(conv(even1, odd2), 1);
    record(conv(even1, even2), 0);

    // cost = sumDy - saving. Reverse so the index becomes gained points (2n - lost).
    for(auto &x: bestSave) x = sumDy - x;
    reverse(bestSave.begin(), bestSave.end());
    return bestSave;
}

// Capacity of the per-node arrays below (maximum supported N plus slack).
static constexpr int MAXN = 200000 + 5;
// Adjacency list: g[u] holds {edge weight, neighbour} pairs.
static vector<vector<pair<int,int>>> g;
// Parent of each node in the most recent dfs() traversal (root gets the
// caller-supplied value, -1 in max_score).
static int parentArr[MAXN];
// distXY[0][v] = distance from X, distXY[1][v] = distance from Y.
static ll distXY[2][MAXN];

// Iterative traversal of the tree `g` from `root`, filling parentArr and one row
// of distXY.
//   root : start vertex; distXY[id][root] is set to 0.
//   p    : value stored as root's parent (max_score passes -1).
//   id   : distance row to fill (max_score uses 0 for X, 1 for Y).
// Skipping only the parent edge is sufficient because g is assumed to be a tree.
// An explicit stack is used so that deep, path-like trees do not recurse.
static void dfs(int root, int p, int id){
    stack<int> st;
    st.push(root);
    parentArr[root] = p;
    distXY[id][root] = 0;   // self-contained: don't rely on the caller zeroing it
    while(!st.empty()){
        int x = st.top(); st.pop();
        for(const auto& [w, y]: g[x]){
            if(y==parentArr[x]) continue;   // edge back to the tree parent
            parentArr[y]=x;
            distXY[id][y] = distXY[id][x] + (ll)w;
            st.push(y);
        }
    }
}

int max_score(int N, int X, int Y, long long K,
              std::vector<int> U, std::vector<int> V, std::vector<int> W) {

    g.assign(N, {});
    for(int i=0;i<N-1;i++){
        g[U[i]].push_back({W[i], V[i]});
        g[V[i]].push_back({W[i], U[i]});
    }

    for(int i=0;i<N;i++) distXY[0][i]=distXY[1][i]=0;

    // distXY[0] from X, distXY[1] from Y, and parentArr from dfs-root
    dfs(Y, -1, 1);
    dfs(X, -1, 0);

    // build path X->Y using parent from dfs(X)?? we used parentArr from dfs(X),
    // but parentArr currently corresponds to dfs(X) run, so chain from Y to X.
    vector<int> onPath(N, 0);
    vector<int> path;
    for(int v=Y; v!=X; v=parentArr[v]){
        path.push_back(v);
        onPath[v]=1;
    }
    path.push_back(X);
    reverse(path.begin(), path.end());
    onPath[X]=1;

    // Off-path items as (min(dx,dy), max(dx,dy))
    vector<pii> coins;
    coins.reserve(N);
    for(int i=0;i<N;i++){
        if(onPath[i]) continue;
        ll dx = distXY[0][i], dy = distXY[1][i];
        if(dx > dy) swap(dx, dy);
        coins.push_back({dx, dy});
    }

    // s1[t] = min cost to gain t points from off-path nodes
    vector<ll> s1 = solveCoins(coins);

    // s2[j-1] = min cost to make exactly j nodes on path type-2
    // (all path nodes are at least type-1; type-2 segment is contiguous on path)
    vector<ll> s2(SZ(path), (ll)2e18);
    {
        int m = SZ(path);
        vector<ll> saveLeft(m), saveRight(m);
        ll tot = 0;
        for(int i=0;i<m;i++){
            int v = path[i];
            ll dx = distXY[0][v], dy = distXY[1][v];
            ll mx = max(dx, dy);
            tot += mx;                 // baseline: all are type-2
            saveLeft[i]  = mx - dx;    // saving if downgraded to type-1 via X
            saveRight[i] = mx - dy;    // saving if downgraded to type-1 via Y
        }

        int l=0, r=m;
        while(l<r){
            int segLen = r-l;          // type-2 segment length
            s2[segLen-1] = min(s2[segLen-1], tot);
            // pick next downgrade on the side that saves more (greedy, but constrained to prefixes/suffixes)
            if(saveLeft[l] > saveRight[r-1]){
                tot -= saveLeft[l++];
            }else{
                tot -= saveRight[--r];
            }
        }
    }

    int ans = 0;

    // Combine off-path points t and path type-2 count j (>=1 if we use s2)
    int j = SZ(s2);
    for(int t=0; t<SZ(s1); t++){
        while(j>0 && s1[t] + s2[j-1] > K) j--;
        if(j>0){
            ans = max(ans, t + j + SZ(path)); // |path| baseline + j extra + t from off-path
        }
    }

    // Also consider solutions that don't force "all path nodes type-1" structure:
    // picking individually cheapest distances to X or Y (simple upper-bound baseline)
    {
        ll cursum = 0;
        vector<ll> allCosts;
        allCosts.reserve(2*N);
        for(int i=0;i<N;i++){
            allCosts.push_back(distXY[0][i]);
            allCosts.push_back(distXY[1][i]);
        }
        sort(allCosts.begin(), allCosts.end());
        for(int i=0;i<SZ(allCosts);i++){
            cursum += allCosts[i];
            if(cursum > K) break;
            ans = max(ans, i+1);
        }
    }

    return ans;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...