Submission #1178710

# | Time | Username | Problem | Language | Result | Execution time | Memory
1178710 | (not captured) | madamadam3 | Closing Time (IOI23_closing) | C++20 | 9 / 100 | 1097 ms | 52940 KiB
#include "closing.h"
#include <bits/stdc++.h>

using namespace std;

// Short aliases used throughout the file.
typedef long long ll;
using vi = vector<int>;
using vvi = vector<vi>;
using vl = vector<ll>;
using pi = pair<int, int>;
using pli = pair<ll, int>;
using vb = vector<bool>;

// Problem instance, (re)initialised on every call to max_score():
//   N    - number of nodes
//   X, Y - the two source nodes
//   K    - budget that the sum of assigned values may not exceed
int N, X, Y; ll K;
// Edge list copied from the max_score() arguments: edge i joins U[i] and V[i]
// and has weight W[i].
vi U, V; vl W;
// adj[u] lists the neighbours of u; weights maps an ordered pair (u, v) to the
// weight of the edge between them (max_score() stores both orientations).
vvi adj; map<pi, int> weights;

/**
 * Distance from `start` to every node, following the edges in `adj` and
 * summing the matching entries of `weights`.
 *
 * The traversal only refuses to step back to the node it arrived from, so it
 * relies on the graph being acyclic (a tree) to terminate.
 *
 * @param start node whose distance is 0
 * @return vector of size N with the accumulated edge weight to each node
 */
vl compute_distances(int start) {
    vl dist_from_start(N, 0LL);
    vi came_from(N, -1);

    // A vector doubles as the FIFO: `head` walks forward while newly
    // discovered nodes are appended at the back.
    vi order;
    order.push_back(start);

    for (size_t head = 0; head < order.size(); ++head) {
        const int cur = order[head];

        for (int nxt : adj[cur]) {
            if (nxt != came_from[cur]) {
                came_from[nxt] = cur;
                dist_from_start[nxt] = dist_from_start[cur] + weights[{cur, nxt}];
                order.push_back(nxt);
            }
        }
    }

    return dist_from_start;
}

struct BFSNode {
    int u, s; ll dist;

    BFSNode(int u, int s, ll dist) {
        this->u = u;
        this->s = s;
        this->dist = dist;
    }

    bool operator<(const BFSNode &other) const {
        return dist < other.dist;
    }

    bool operator>(const BFSNode &other) const {
        return dist > other.dist;
    }
};

int solve_disconnected() {
    vector<vb> vis(2, vb(N, false));
    vl tim(N+1, 0LL);

    int reachable = 0;
    ll cur_sum = 0;

    priority_queue<BFSNode, vector<BFSNode>, greater<BFSNode>> q;
    q.push(BFSNode(X, 0, 0LL));
    q.push(BFSNode(Y, 1, 0LL));

    while (!q.empty()) {
        BFSNode cur = q.top();
        int u = cur.u, s = cur.s;
        ll dist = cur.dist;
        q.pop();

        if (vis[s][u]) continue;
        vis[s][u] = true;

        // cout << "Node = " << u << " Source = " << s << " Dist = " << dist << "\n";
        // cout << "Current Sum = " << cur_sum << "\n";

        if (dist + cur_sum > K) continue;

        reachable++;
        cur_sum += dist;
        tim[u] += dist;

        for (int v : adj[u]) {
            if (vis[s][v]) continue;

            ll ndist = tim[u] + weights[{u, v}] - tim[v];
            // cout << "Neighbour = " << v << " Cost = " << ndist << "\n";
            q.push(BFSNode(v, cur.s, max(0LL, ndist)));
        }
    }

    return reachable;
}

int compute_value(int xL, int xR, int yL, int yR) {
    vector<int> costX(N, 0), costY(N, 0);
    
    costX[X] = 0;
    for (int i = X - 1; i >= xL; i--) {
        costX[i] = costX[i + 1] + weights[{i, i + 1}];
    }
    for (int i = X + 1; i <= xR; i++) {
        costX[i] = costX[i - 1] + weights[{i - 1, i}];
    }
    
    costY[Y] = 0;
    for (int i = Y - 1; i >= yL; i--) {
        costY[i] = costY[i + 1] + weights[{i, i + 1}];
    }
    for (int i = Y + 1; i <= yR; i++) {
        costY[i] = costY[i - 1] + weights[{i - 1, i}];
    }
    
    int total = 0;
    int unionL = min(xL, yL);
    int unionR = max(xR, yR);
    for (int i = unionL; i <= unionR; i++) {
        bool inX = (i >= xL && i <= xR);
        bool inY = (i >= yL && i <= yR);
        if (inX && inY) {
            total += max(costX[i], costY[i]);
        } else if (inX) {
            total += costX[i];
        } else if (inY) {
            total += costY[i];
        }
    }
    return total;
}

/**
 * Exhaustive search over every pair of index ranges, one containing X and one
 * containing Y, keeping the largest combined size whose compute_value() cost
 * fits into K.
 *
 * Side effect: the globals X and Y are swapped if X > Y.
 *
 * @return best combined range size found (0 if no combination is affordable)
 */
int solve_line() {
    if (X > Y) swap(X, Y);

    int best = 0;
    for (int xL = 0; xL <= X; xL++) {
        for (int xR = X; xR < N; xR++) {
            const int spanX = xR - xL + 1;

            for (int yL = 0; yL <= Y; yL++) {
                for (int yR = Y; yR < N; yR++) {
                    if (compute_value(xL, xR, yL, yR) > K) continue;

                    const int spanY = yR - yL + 1;
                    best = max(best, spanX + spanY);
                }
            }
        }
    }

    return best;
}

/**
 * Placeholder for a small-input solver; no algorithm is implemented yet.
 *
 * @return always 0
 */
int solve_small() {
    return 0;
}

void bfs_cap(int maxnodes, vl &tim, vb &vis, vl &xdist, int &seen, ll &cur_sum) {
    priority_queue<pli, vector<pli>, greater<pli>> q;
    q.push({0LL, X});

    while (!q.empty()) {
        int u = q.top().second;
        ll dist = q.top().first;
        q.pop();

        if (vis[u]) continue;
        if (seen >= maxnodes) break;
        vis[u] = true;
        seen++;

        cur_sum += dist;
        tim[u] = dist;

        for (int v : adj[u]) {
            if (vis[v]) continue;
            q.push({xdist[v], v});
        }
    }
}

/**
 * Grows a region around Y with whatever budget is left, cheapest frontier
 * node first, and stops at the first node that would push `cur_sum` past K.
 *
 * A node already marked in `vis` is charged only max(0, ydist - xdist);
 * any other node is charged its full `ydist`.
 *
 * @param tim     overwritten with the charge of each node taken here
 * @param vis     nodes taken by the earlier pass (read only)
 * @param xdist   per-node cost with respect to X
 * @param ydist   per-node cost with respect to Y
 * @param seen    incremented once per node taken
 * @param cur_sum increased by the charge of each node taken
 */
void bfs_rem(vl &tim, vb &vis, vl &xdist, vl &ydist, int &seen, ll &cur_sum) {
    vb taken(N, false);
    priority_queue<pli, vector<pli>, greater<pli>> heap;
    heap.emplace(0LL, Y);

    while (!heap.empty()) {
        const pli best = heap.top();
        heap.pop();

        const ll price = best.first;
        const int node = best.second;

        if (taken[node]) continue;            // stale duplicate entry
        if (cur_sum + price > K) break;       // cheapest candidate no longer fits

        taken[node] = true;
        ++seen;
        cur_sum += price;
        tim[node] = price;

        for (int nb : adj[node]) {
            if (taken[nb]) continue;

            const ll price_nb = vis[nb] ? max(0LL, ydist[nb] - xdist[nb])
                                        : ydist[nb];
            heap.emplace(price_nb, nb);
        }
    }
}

int get_reachable(int x_cap, vl &xdist, vl &ydist) {
    vl tim(N, 0LL);
    vb vis(N, false);
    int seen = 0;
    ll cur_sum = 0LL;

    bfs_cap(x_cap, tim, vis, xdist, seen, cur_sum);
    bfs_rem(tim, vis, xdist, ydist, seen, cur_sum);

    if (cur_sum > K) return 0;
    return seen;
}

int solve() {
    vl distX = compute_distances(X), distY = compute_distances(Y);

    
    int mx = 0;
    for (int i = 1; i <= N; i++) {
        int v = get_reachable(i, distX, distY);
        mx = max(mx, v);
    }

    return mx;
}

/*
    N: number of cities, N <= 200k
    X, Y: the two cities of interest, 0 <= X, Y < N and X != Y
    K: maximum allowed sum of closing times, K <= 10^18
    U, V: endpoints of the roads; road j connects U[j] and V[j]
    W: W[j] = length of road j
*/
/**
 * Grader entry point: loads one problem instance into the file-level globals
 * and returns the score computed by solve().
 *
 * Every global touched here is reset first, so the function can be called
 * repeatedly for independent instances.
 *
 * @param _N number of nodes
 * @param _X first source node
 * @param _Y second source node
 * @param _K budget
 * @param _U,_V edge endpoints; edge i joins _U[i] and _V[i] (N - 1 edges read)
 * @param _W edge weights, _W[i] belongs to edge i
 * @return value returned by solve()
 */
int max_score(int _N, int _X, int _Y, ll _K, vi _U, vi _V, vi _W) {
    N = _N; X = _X; Y = _Y; K = _K;
    U = _U; V = _V;

    W.clear();
    if (N > 1) W.reserve(N - 1);
    for (int i = 0; i < N - 1; i++) W.push_back(ll(_W[i]));

    weights.clear();
    adj.assign(N, vi());
    for (int i = 0; i < N - 1; i++) {
        adj[U[i]].push_back(V[i]);
        adj[V[i]].push_back(U[i]);
        // Stored in both orientations so lookups work from either endpoint.
        weights[{U[i], V[i]}] = W[i];
        weights[{V[i], U[i]}] = W[i];
    }

    // Only the general heuristic is dispatched to. An earlier, disabled
    // dispatch picked solve_disconnected() when dist(X, Y) > 2 * K,
    // solve_line() when every edge joins i and i + 1, and solve_small() when
    // N <= 3000. The distance and line-shape precomputation that served only
    // that dispatch is no longer done here; solve() computes the distances it
    // needs itself.
    return solve();
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...