Submission #839918

# | Time | Username | Problem | Language | Result | Execution time | Memory
839918 | (not captured) | jtnydv25 | Closing Time (IOI23_closing) | C++17 | 100 / 100 | 169 ms | 50000 KiB
#include "closing.h"

#include <bits/stdc++.h>
using namespace std;

// Shorthand for `long long`, the type used below for distances and the budget K.
#define ll long long
// Expands to the two arguments `c.begin(), c.end()` so a whole container can be
// handed to a standard algorithm, e.g. sort(all(v)).
#define all(c) ((c).begin()), ((c).end())
// Size of container x converted to int.
#define sz(x) ((int)(x).size())

// Debug support: when LOCAL is defined, <print.h> is included (presumably it
// supplies trace(...); that header is not part of this file). Otherwise
// trace(...) expands to nothing and `endl` is replaced by a plain newline
// string literal (presumably to avoid the flush that std::endl performs).
#ifdef LOCAL
#include <print.h>
#else
#define trace(...)
#define endl "\n" // remove in interactive
#endif
/**
 * Solves the "Closing Time" task (IOI23_closing, per the submission header).
 *
 * The input is a tree on n nodes whose i-th edge (0 <= i < n-1) joins U[i] and
 * V[i] with length W[i]. For every node let near = min(dist to x, dist to y)
 * and far = max(dist to x, dist to y). Paying `near` for a node scores one
 * point, paying `far` scores two. The function returns the best total score
 * whose cost does not exceed K, taking the better of two plans:
 *
 *  1. Every node scores at most one point: greedily buy the cheapest `near`
 *     values.
 *  2. Some node scores two points: every node on the x-y path (the nodes with
 *     dist_x + dist_y == dist_x[y]) is first bought at its `near` price; if
 *     that alone is unaffordable, plan 1 is returned. The remaining budget is
 *     spent on further points, maximised by a binary search on their count.
 *
 * @param n  number of nodes (the code reads n-1 edges)
 * @param x  first special node
 * @param y  second special node
 * @param K  total budget
 * @param U  first endpoint of each edge
 * @param V  second endpoint of each edge
 * @param W  length of each edge
 * @return   the maximum score found
 */
int max_score(int n, int x, int y, long long K, std::vector<int> U, std::vector<int> V, std::vector<int> W) {
    using i64 = long long;
    constexpr i64 INF = 2000000000000000000LL;

    // incident[v] lists the indices of the edges touching v.
    std::vector<std::vector<int>> incident(n);
    for (int e = 0; e + 1 < n; ++e) {
        incident[U[e]].push_back(e);
        incident[V[e]].push_back(e);
    }

    // Distances from `root` to every node. Uses an explicit stack rather than
    // recursion so that a path-shaped tree cannot overflow the call stack.
    auto distances_from = [&](int root) {
        std::vector<i64> dist(n, 0);
        std::vector<std::pair<int, int>> pending;  // (node, node it was reached from)
        pending.emplace_back(root, root);
        while (!pending.empty()) {
            const auto [node, parent] = pending.back();
            pending.pop_back();
            for (int e : incident[node]) {
                const int child = U[e] ^ V[e] ^ node;  // the other endpoint of edge e
                if (child == parent) continue;
                dist[child] = dist[node] + W[e];
                pending.emplace_back(child, node);
            }
        }
        return dist;
    };

    const std::vector<i64> distX = distances_from(x);
    const std::vector<i64> distY = distances_from(y);
    const i64 pathLength = distX[y];

    std::vector<i64> nearCost(n), farCost(n);
    std::vector<char> onPath(n, 0);
    int pathNodes = 0;
    for (int v = 0; v < n; ++v) {
        nearCost[v] = std::min(distX[v], distY[v]);
        farCost[v] = std::max(distX[v], distY[v]);
        onPath[v] = (distX[v] + distY[v] == pathLength);
        if (onPath[v]) ++pathNodes;
    }

    // Plan 1: one point per node, cheapest first. Costs are sorted ascending,
    // so the first unaffordable cost ends the scan.
    std::vector<i64> sortedNear = nearCost;
    std::sort(sortedNear.begin(), sortedNear.end());
    i64 singleScore = 0;
    i64 spent = 0;
    for (i64 cost : sortedNear) {
        if (spent + cost > K) break;
        spent += cost;
        ++singleScore;
    }

    // Plan 2 starts by buying one point for every node on the x-y path.
    i64 budget = K;
    for (int v = 0; v < n; ++v) {
        if (!onPath[v]) continue;
        if (nearCost[v] > budget) return static_cast<int>(singleScore);
        budget -= nearCost[v];
    }

    // Classify what can still be bought:
    //  * singles  - independent one-point purchases. A path node offers its
    //               second point for far-near. An off-path node with
    //               far >= 2*near offers `near` and then `far-near`; since
    //               near <= far-near, the sorted greedy takes them in order.
    //  * pairs    - off-path nodes with far < 2*near: one point costs
    //               pairFirst (= near), both points cost pairBoth (= far).
    std::vector<i64> singles, pairFirst, pairBoth;
    for (int v = 0; v < n; ++v) {
        const i64 lowCost = nearCost[v];
        const i64 highCost = farCost[v];
        if (onPath[v]) {
            singles.push_back(highCost - lowCost);
        } else if (highCost >= 2 * lowCost) {
            singles.push_back(lowCost);
            singles.push_back(highCost - lowCost);
        } else {
            pairFirst.push_back(lowCost);
            pairBoth.push_back(highCost);
        }
    }
    std::sort(singles.begin(), singles.end());

    const i64 pairCount = static_cast<i64>(pairFirst.size());
    const i64 singleCount = static_cast<i64>(singles.size());

    // Pair items ordered by the price of taking both points.
    std::vector<i64> byBoth(pairCount);
    std::iota(byBoth.begin(), byBoth.end(), i64{0});
    std::sort(byBoth.begin(), byBoth.end(),
              [&](i64 a, i64 b) { return pairBoth[a] < pairBoth[b]; });

    // cheapestFirstFrom[i]: smallest one-point price among byBoth[i..].
    std::vector<i64> cheapestFirstFrom(pairCount + 1, INF);
    for (i64 i = pairCount - 1; i >= 0; --i) {
        cheapestFirstFrom[i] = std::min(cheapestFirstFrom[i + 1], pairFirst[byBoth[i]]);
    }

    // pairCost[t]: minimum price of t points taken from pair items only.
    //  even t = 2j   : the j cheapest items taken in full;
    //  odd  t = 2i+1 : either i full items plus the cheapest single point
    //                  among the rest, or i+1 full items with one of them
    //                  downgraded to a single point (largest saving).
    std::vector<i64> pairCost(2 * pairCount + 1, INF);
    pairCost[0] = 0;
    i64 fullPrefix = 0;
    i64 bestDowngrade = INF;  // min over taken items of (pairFirst - pairBoth), i.e. <= 0
    for (i64 i = 0; i < pairCount; ++i) {
        const i64 item = byBoth[i];
        const i64 withExtraSingle = fullPrefix + cheapestFirstFrom[i];
        bestDowngrade = std::min(bestDowngrade, pairFirst[item] - pairBoth[item]);
        fullPrefix += pairBoth[item];
        pairCost[2 * (i + 1)] = fullPrefix;
        pairCost[2 * i + 1] = std::min({pairCost[2 * i + 1], withExtraSingle, fullPrefix + bestDowngrade});
    }

    // Can `target` extra points be bought with the remaining budget, using the
    // `taken` cheapest singles plus (target - taken) points from pair items?
    auto affordable = [&](i64 target) {
        i64 singlesPrefix = 0;
        for (i64 taken = 0; taken <= singleCount && taken <= target; ++taken) {
            const i64 fromPairs = target - taken;
            if (fromPairs <= 2 * pairCount && singlesPrefix + pairCost[fromPairs] <= budget) {
                return true;
            }
            if (taken != singleCount) singlesPrefix += singles[taken];
        }
        return false;
    };

    // Largest affordable number of extra points (upper-biased binary search).
    i64 lo = 0;
    i64 hi = 2 * static_cast<i64>(n) - pathNodes;
    while (lo < hi) {
        const i64 mid = (lo + hi + 1) / 2;
        if (affordable(mid)) {
            lo = mid;
        } else {
            hi = mid - 1;
        }
    }

    return static_cast<int>(std::max(singleScore, lo + pathNodes));
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...