Submission #1360881

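| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 1360881 | | tickcrossy | Closing Time (IOI23_closing) | C++20 | 75 / 100 | 1096 ms | 33300 KiB |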
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <queue>
#include <utility>
#include <vector>

using namespace std;

// Sentinel meaning "no valid assignment" in the DP tables; 2 * 10^18, spelled as an
// exact integer literal instead of a floating-point constant converted at init time.
constexpr long long INF = 2'000'000'000'000'000'000LL;

// Weighted tree edge plus a flag recording whether it lies on the X-Y path.
// Member order matters: instances are built with aggregate initialization {to, w, on_P}.
struct Edge {
    int to;        // endpoint the edge leads to
    long long w;   // edge length
    bool on_P;     // true when both endpoints are X-Y path vertices
};

namespace {

// Adjacency list of the weighted tree: adj[u] holds (neighbour, edge length) pairs.
using ClosingAdjacency = std::vector<std::vector<std::pair<int, long long>>>;

// "Impossible / unreachable" sentinel. Assumes every real cost stays far below 2e18
// (true for the IOI23 limits: N <= 2e5, W <= 1e6) -- re-check if limits change.
constexpr long long kClosingInf = 2'000'000'000'000'000'000LL;

// The exact DP below is quadratic in N; above this size the O(N log N) greedy is used.
constexpr int kClosingExactDpMaxN = 3000;

// Computes dist[v] = tree distance from src to v and parent[v] = the vertex before v on
// the path from src (-1 for src). Uses an explicit stack, so path-shaped trees cannot
// overflow the call stack; in a tree the first visit already gives the exact distance.
void closing_walk_from(const ClosingAdjacency& adj, int src,
                       std::vector<long long>& dist, std::vector<int>& parent) {
    const int n = static_cast<int>(adj.size());
    dist.assign(n, -1);
    parent.assign(n, -1);
    std::vector<int> stack{src};
    dist[src] = 0;
    while (!stack.empty()) {
        const int u = stack.back();
        stack.pop_back();
        for (const auto& [v, w] : adj[u]) {
            if (dist[v] != -1) continue;  // already visited
            dist[v] = dist[u] + w;
            parent[v] = u;
            stack.push_back(v);
        }
    }
}

// Per-vertex DP state: bit 0 = reached from X, bit 1 = reached from Y.
constexpr bool closing_in_x(int state) { return (state & 1) != 0; }
constexpr bool closing_in_y(int state) { return (state & 2) != 0; }

// Whether parent state su and child state sv (child in the X-rooted tree) keep both reach
// sets connected. The X-set must contain a vertex's parent. The Y-set must too, except on
// X-Y path edges, where Y lies on the child's side and the implication runs parent -> child.
bool closing_states_compatible(int su, int sv, bool path_edge) {
    if (closing_in_x(sv) && !closing_in_x(su)) return false;
    if (path_edge) return !(closing_in_y(su) && !closing_in_y(sv));
    return !(closing_in_y(sv) && !closing_in_y(su));
}

// Exact tree knapsack over the subtree of u (tree rooted at X, u's parent is `parent`).
// Result dp[s][c] = minimum total closing time spent inside the subtree when u is in
// state s and the subtree earns c points (state 3 earns 2 points), or kClosingInf.
// A vertex costs dX in state 1, dY in state 2, max(dX, dY) in state 3.
std::vector<std::vector<long long>> closing_subtree_dp(
        const ClosingAdjacency& adj, const std::vector<char>& on_path,
        const std::vector<long long>& dX, const std::vector<long long>& dY,
        int u, int parent) {
    std::vector<std::vector<long long>> dp(4, std::vector<long long>(3, kClosingInf));
    dp[0][0] = 0;
    dp[1][1] = dX[u];
    dp[2][1] = dY[u];
    dp[3][2] = std::max(dX[u], dY[u]);
    int size_u = 1;  // vertices merged into dp so far

    for (const auto& edge : adj[u]) {
        const int v = edge.first;
        if (v == parent) continue;
        const auto dp_v = closing_subtree_dp(adj, on_path, dX, dY, v, u);
        const int size_v = static_cast<int>(dp_v[0].size() - 1) / 2;
        // In a tree, two adjacent path vertices are joined by a path edge.
        const bool path_edge = on_path[u] && on_path[v];

        std::vector<std::vector<long long>> merged(
            4, std::vector<long long>(2 * (size_u + size_v) + 1, kClosingInf));
        for (int su = 0; su < 4; ++su) {
            for (int sv = 0; sv < 4; ++sv) {
                if (!closing_states_compatible(su, sv, path_edge)) continue;
                for (int cu = 0; cu <= 2 * size_u; ++cu) {
                    if (dp[su][cu] == kClosingInf) continue;
                    for (int cv = 0; cv <= 2 * size_v; ++cv) {
                        if (dp_v[sv][cv] == kClosingInf) continue;
                        merged[su][cu + cv] =
                            std::min(merged[su][cu + cv], dp[su][cu] + dp_v[sv][cv]);
                    }
                }
            }
        }
        size_u += size_v;
        dp = std::move(merged);
    }
    return dp;
}

// Exact answer via the tree DP (quadratic; used only for small N).
long long closing_exact_score(const ClosingAdjacency& adj, int X, long long K,
                              const std::vector<long long>& dX, const std::vector<long long>& dY,
                              const std::vector<char>& on_path) {
    const auto dp = closing_subtree_dp(adj, on_path, dX, dY, X, -1);
    long long best = 0;
    for (const auto& row : dp) {
        for (std::size_t c = 0; c < row.size(); ++c) {
            if (row[c] != kClosingInf && row[c] <= K) {
                best = std::max(best, static_cast<long long>(c));
            }
        }
    }
    return best;
}

// Exact O(N log N) answer, taking the better of two cases.
//  1. No vertex is reached from both X and Y. Each counted vertex costs at least
//     min(dX, dY), and taking the cheapest ones in sorted order is always feasible.
//  2. Some vertex is reached from both. Then every X-Y path vertex is reached, which
//     costs min(dX, dY) each up front. Every further point is an item:
//     - a path vertex's upgrade to both sets costs |dX - dY|;
//     - an off-path vertex costs a = min(dX, dY) for its first point and
//       b = |dX - dY| for its second.
//     If a <= b, the two points are independent unit items. If a > b, the vertex is a
//     "pair" (a for 1 point, a + b for 2), and an optimal choice half-takes at most one pair.
long long closing_greedy_score(long long K, const std::vector<long long>& dX,
                               const std::vector<long long>& dY,
                               const std::vector<char>& on_path) {
    const int n = static_cast<int>(dX.size());

    // Case 1: cheapest single reaches.
    std::vector<long long> mins(n);
    for (int i = 0; i < n; ++i) mins[i] = std::min(dX[i], dY[i]);
    std::sort(mins.begin(), mins.end());
    long long best = 0;
    long long spent = 0;
    for (const long long cost : mins) {
        if (cost > K - spent) break;  // compare without overflowing spent + cost
        spent += cost;
        ++best;
    }

    // Case 2: pay for the whole X-Y path first, then classify the remaining points.
    long long budget = K;
    long long base_points = 0;
    std::vector<long long> singles;
    std::vector<std::pair<long long, long long>> pairs;  // (a, b) with a > b
    for (int i = 0; i < n; ++i) {
        const long long a = std::min(dX[i], dY[i]);
        const long long b = std::max(dX[i], dY[i]) - a;
        if (on_path[i]) {
            budget -= a;
            ++base_points;
            singles.push_back(b);
        } else if (a <= b) {
            singles.push_back(a);
            singles.push_back(b);
        } else {
            pairs.emplace_back(a, b);
        }
    }
    if (budget < 0) return best;

    std::sort(singles.begin(), singles.end());
    std::sort(pairs.begin(), pairs.end(), [](const auto& l, const auto& r) {
        return l.first + l.second < r.first + r.second;
    });

    const int m = static_cast<int>(pairs.size());
    std::vector<long long> full(m + 1, 0);   // full[k]: cost of the k cheapest whole pairs
    std::vector<long long> max_b(m + 1, 0);  // max_b[k]: largest b among those k pairs
    for (int k = 0; k < m; ++k) {
        full[k + 1] = full[k] + pairs[k].first + pairs[k].second;
        max_b[k + 1] = std::max(max_b[k], pairs[k].second);
    }
    std::vector<long long> min_a(m + 1, kClosingInf);  // min_a[k]: smallest a in pairs[k..]
    for (int k = m - 1; k >= 0; --k) min_a[k] = std::min(min_a[k + 1], pairs[k].first);

    // need[v]: cheapest way to earn at least v points from pairs alone.
    std::vector<long long> need(2 * m + 1);
    for (int k = 0; k <= m; ++k) need[2 * k] = full[k];
    for (int k = 0; k < m; ++k) {
        // 2k+1 points: k whole pairs plus one half pair. The half pair is either outside
        // the k cheapest (take its smallest a) or among the k+1 cheapest (drop its largest b).
        need[2 * k + 1] = std::min(full[k] + min_a[k], full[k + 1] - max_b[k + 1]);
    }
    for (int v = 2 * m - 1; v >= 0; --v) need[v] = std::min(need[v], need[v + 1]);

    long long best_extra = 0;
    long long singles_cost = 0;
    for (std::size_t s = 0; s <= singles.size(); ++s) {
        if (s > 0) singles_cost += singles[s - 1];
        if (singles_cost > budget) break;
        const long long left = budget - singles_cost;
        const long long pair_points =
            static_cast<long long>(std::upper_bound(need.begin(), need.end(), left) -
                                   need.begin()) - 1;
        best_extra = std::max(best_extra, static_cast<long long>(s) + pair_points);
    }
    return std::max(best, base_points + best_extra);
}

}  // namespace

// Maximum score for the tree given by edges (U[i], V[i], W[i]), i < N - 1, with total
// closing-time budget K. The model, as encoded by the DP above:
// - a vertex is reached from X when every vertex on its path from X has a closing time of
//   at least its distance from X (likewise for Y);
// - the score is the number of vertices reached from X plus those reached from Y.
// Small trees use the exact tree DP; larger ones use the exact O(N log N) case analysis.
long long max_score(int N, int X, int Y, long long K, std::vector<int> U, std::vector<int> V, std::vector<int> W) {
    ClosingAdjacency adj(N);
    for (int i = 0; i < N - 1; ++i) {
        adj[U[i]].emplace_back(V[i], W[i]);
        adj[V[i]].emplace_back(U[i], W[i]);
    }

    std::vector<long long> dX;
    std::vector<long long> dY;
    std::vector<int> parent_from_x;
    std::vector<int> parent_from_y;
    closing_walk_from(adj, X, dX, parent_from_x);
    closing_walk_from(adj, Y, dY, parent_from_y);

    // Mark the X-Y path by following predecessors (towards X) back from Y.
    std::vector<char> on_path(N, 0);
    for (int v = Y; v != -1; v = parent_from_x[v]) on_path[v] = 1;

    if (N <= kClosingExactDpMaxN) return closing_exact_score(adj, X, K, dX, dY, on_path);
    return closing_greedy_score(K, dX, dY, on_path);
}
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...
#Result Execution timeMemoryGrader output
Fetching results...