Submission #999154

# Submission time Handle Problem Language Result Execution time Memory
999154 2024-06-15T07:26:46 Z mdn2002 Closing Time (IOI23_closing) C++17
Compilation error
0 ms 0 KB
/*
Mayoeba Yabureru
*/
#include<bits/stdc++.h>
using namespace std;

/*
 * Returns the best convenience score found for a tree of n cities (0-based
 * endpoints in U/V, edge lengths in W), the two special cities x and y, and a
 * total closing-time budget k.
 *
 * Two strategies are evaluated and the larger score is returned:
 *   1. Greedy: every city is paid for at most once, at the cheaper of its
 *      distance from x and its distance from y, cheapest cities first.
 *   2. Overlap DP: the city minimising max(dist from x, dist from y) is forced
 *      to be reachable from both x and y, and a tree knapsack rooted at that
 *      city chooses a state for every other city.
 *      NOTE(review): this relies on that city being doubly covered in some
 *      optimal overlapping solution - confirm against the task analysis.
 *
 * The DP table holds (n + 1) * (2n + 1) * 4 costs and is allocated per call,
 * only when the overlap case is affordable. The quadratic size presumably
 * targets the small-n subtask (the previous fixed table was 3003 x 6003).
 */
int max_score(int n, int x, int y, long long k, vector<int> U, vector<int> V, vector<int> W) {
    const long long kInf = 10000000000000000LL;  // 1e16: marks an unreachable DP state
    const int src[2] = {x + 1, y + 1};           // vertices are handled 1-based; 0 means "no parent"

    vector<vector<pair<int, long long>>> gr(n + 1);
    for (int i = 0; i < n - 1; i++) {
        const int u = U[i] + 1, v = V[i] + 1;
        gr[u].push_back({v, W[i]});
        gr[v].push_back({u, W[i]});
    }

    // dis[v][0] = distance from x to v, dis[v][1] = distance from y to v.
    vector<vector<long long>> dis(n + 1, vector<long long>(2, 0));
    function<void(int, int, int)> fill_dist = [&](int v, int parent, int which) {
        for (const auto& edge : gr[v]) {
            if (edge.first == parent) continue;
            dis[edge.first][which] = dis[v][which] + edge.second;
            fill_dist(edge.first, v, which);
        }
    };
    fill_dist(src[0], 0, 0);
    fill_dist(src[1], 0, 1);

    // path = vertex sequence of the unique x -> y route.
    vector<int> trail, path;
    function<void(int, int)> find_path = [&](int v, int parent) {
        trail.push_back(v);
        if (v == src[1]) path = trail;
        for (const auto& edge : gr[v]) {
            if (edge.first != parent) find_path(edge.first, v);
        }
        trail.pop_back();
    };
    find_path(src[0], 0);

    // meet = first vertex with the smallest max(dist from x, dist from y).
    int meet = 0;
    long long meet_cost = 1e15;
    for (int v = 1; v <= n; v++) {
        const long long both = max(dis[v][0], dis[v][1]);
        if (both < meet_cost) {
            meet = v;
            meet_cost = both;
        }
    }

    // onpath[v]: 1 = on the route strictly before meet (x side), 2 = strictly
    // after it (y side), 3 = meet itself, 0 = off the route.
    // need = cost of covering meet from both ends: meet pays its max, the
    // x-side route vertices pay their x distance, the y-side ones their y distance.
    vector<int> onpath(n + 1);
    long long need = max(dis[meet][0], dis[meet][1]);
    for (int v : path) {
        onpath[v] += 1;
        if (v == meet) break;
        need += dis[v][0];
    }
    for (auto it = path.rbegin(); it != path.rend(); ++it) {
        onpath[*it] += 2;
        if (*it == meet) break;
        need += dis[*it][1];
    }

    // Strategy 1. Works on a copy of the budget so that k stays intact for the
    // affordability check and the DP below.
    int ans = 0;
    {
        vector<long long> cheapest;
        cheapest.reserve(n);
        for (int v = 1; v <= n; v++) cheapest.push_back(min(dis[v][0], dis[v][1]));
        sort(cheapest.begin(), cheapest.end());
        long long budget = k;
        for (long long cost : cheapest) {
            if (cost > budget) break;
            budget -= cost;
            ans++;
        }
    }

    // The overlap strategy cannot beat the greedy one unless covering meet
    // from both sides fits into the FULL budget.
    if (need > k) return ans;

    // cost(v, cnt, rx, ry): minimal total closing time inside v's subtree that
    // yields score cnt when v is reachable from x iff rx and from y iff ry.
    const size_t width = static_cast<size_t>(2 * n + 1);
    vector<long long> table(static_cast<size_t>(n + 1) * width * 4, kInf);
    auto cost = [&](int v, int cnt, int rx, int ry) -> long long& {
        return table[((static_cast<size_t>(v) * width + cnt) * 2 + rx) * 2 + ry];
    };

    vector<int> siz(n + 1);
    function<void(int, int)> calc_size = [&](int v, int parent) {
        siz[v] = 1;
        for (const auto& edge : gr[v]) {
            if (edge.first == parent) continue;
            calc_size(edge.first, v);
            siz[v] += siz[edge.first];
        }
    };
    calc_size(meet, 0);

    vector<char> done(static_cast<size_t>(n + 1) * 4, 0);  // memo flag per (vertex, rx, ry)
    vector<long long> merged(2 * (n + 1), kInf);          // knapsack scratch, kInf between merges
    function<void(int, int, int, int)> solve = [&](int v, int parent, int rx, int ry) {
        const size_t key = (static_cast<size_t>(v) * 2 + rx) * 2 + ry;
        if (done[key]) return;
        done[key] = 1;

        // Contribution of v alone.
        if (rx == 0 && ry == 0) cost(v, 0, rx, ry) = 0;
        if (rx && ry == 0) cost(v, 1, rx, ry) = min(cost(v, 1, rx, ry), dis[v][0]);
        if (ry && rx == 0) cost(v, 1, rx, ry) = min(cost(v, 1, rx, ry), dis[v][1]);
        if (rx && ry) cost(v, 2, rx, ry) = min(cost(v, 2, rx, ry), max(dis[v][0], dis[v][1]));

        int merged_size = 1;
        for (const auto& edge : gr[v]) {
            const int u = edge.first;
            if (u == parent) continue;

            // A child may keep the parent's state or drop coverage, except that
            // x-side route vertices must stay reachable from x and y-side ones
            // from y.
            solve(u, v, rx, ry);
            if (rx && ry) {
                if (onpath[u] == 0) solve(u, v, 0, 0);
                if (onpath[u] != 1) solve(u, v, 0, 1);
                if (onpath[u] != 2) solve(u, v, 1, 0);
            }
            else if (rx && onpath[u] != 1) solve(u, v, 0, ry);
            else if (ry && onpath[u] != 2) solve(u, v, rx, 0);

            // Knapsack-merge the child's tables into v's table. Child states
            // that were never solved still hold kInf and are skipped.
            for (int cx = 0; cx <= rx; cx++) {
                for (int cy = 0; cy <= ry; cy++) {
                    for (int z = 0; z <= 2 * siz[u]; z++) {
                        const long long child = cost(u, z, cx, cy);
                        if (child >= kInf) continue;
                        for (int c = 0; c <= 2 * merged_size; c++) {
                            merged[c + z] = min(merged[c + z], cost(v, c, rx, ry) + child);
                        }
                    }
                }
            }
            merged_size += siz[u];
            for (int c = 0; c <= 2 * merged_size; c++) {
                cost(v, c, rx, ry) = merged[c];
                merged[c] = kInf;
            }
        }
    };
    solve(meet, 0, 1, 1);

    for (int cnt = 0; cnt <= 2 * n; cnt++) {
        if (k >= cost(meet, cnt, 1, 1)) ans = max(ans, cnt);
    }
    return ans;
}
/*
2
7 0 2 10
0 1 2
0 3 3
1 2 4
2 4 2
2 5 5
5 6 3

4 0 3 20
0 1 18
1 2 1
2 3 19
*/

// Shorthand macros: container size as int, and a (begin, end) iterator pair.
#define sz(x) int(x.size())
#define all(x) begin(x), end(x)

using ll = long long;
using ii = pair<int, int>;

const ll INF = 1e18;        // "unreachable" marker for the DP in max_score1
const int MAX_N = 2e5 + 9;  // capacity of the fixed-size arrays below

// State shared by dfs, non_overlaping and max_score1; max_score1 resets it on
// every call.
vector<ii> adj[MAX_N];      // adjacency lists of (neighbour, edge weight), 0-based vertices
pair<ll, ll> val[MAX_N];    // per vertex: (smaller, larger) of its distances to X and Y
int n;                      // vertex count of the instance being solved

// Depth-first walk over adj starting at u: stores in dist[v] the sum of edge
// weights on the walk from the initial vertex to v. p is the vertex we came
// from (-1 for the start) and d is the distance accumulated so far.
void dfs(int u, vector<ll> &dist, int p = -1, ll d = 0) {
    dist[u] = d;
    for (const auto &edge : adj[u]) {
        const int next = edge.first;
        if (next == p) continue;
        dfs(next, dist, u, d + edge.second);
    }
}

// Greedy count for the case where no vertex is paid for twice: walks the
// per-vertex cheaper costs (val[i].first) in ascending order and takes every
// one that still fits into the remaining budget K.
int non_overlaping(ll K) {
    vector<ll> costs;
    costs.reserve(n);
    for (int i = 0; i < n; i++) {
        costs.push_back(val[i].first);
    }
    sort(costs.begin(), costs.end());

    int taken = 0;
    for (ll cost : costs) {
        if (cost > K) continue;
        K -= cost;
        ++taken;
    }
    return taken;
}

int max_score1(int N, int X, int Y, ll K, vector<int> U, vector<int> V, vector<int> W) {
    n = N;
    for (int i = 0; i < n; i++) adj[i].clear();
    for (int i = 0; i < n - 1; i++) {
        adj[U[i]].emplace_back(V[i], W[i]);
        adj[V[i]].emplace_back(U[i], W[i]);
    }
    vector<ll> distToX(n), distToY(n);
    dfs(X, distToX), dfs(Y, distToY);
    bool flag = false;
    for (int i = 0; i < n; i++) {
        val[i] = minmax(distToX[i], distToY[i]);
        if (val[i].second <= K) flag = true;
    }
    if (!flag) return non_overlaping(K);
    vector<bool> path(n, false);
    path[Y] = true;
    int u = Y;
    while (u != X) {
        for (auto [v, w] : adj[u]) {
            if (distToX[v] + w == distToX[u]) {
                u = v;
                break;
            }
        }
        path[u] = true;
    }
    vector<ll> prevDP(2 * n + 1, INF);
    prevDP[0] = 0;
    for (int i = 0; i < n; i++) {
        vector<ll> currDP(2 * n + 1, INF);
        if (!path[i]) currDP = prevDP;
        for (int j = 1; j <= 2 * n; j++) {
            currDP[j] = min(currDP[j], prevDP[j - 1] + val[i].first);
        }
        for (int j = 2; j <= 2 * n; j++) {
            currDP[j] = min(currDP[j], prevDP[j - 2] + val[i].second);
        }
        prevDP = currDP;
    }
    for (int i = 2 * n; i >= 0; i--) {
        if (prevDP[i] <= K) return max(i, non_overlaping(K));
    }
    return non_overlaping(K);
}

// Local test driver: reads T test cases (n, x, y, k followed by n - 1 edges
// "u v w") and prints the answers of max_score and max_score1 on separate
// lines so they can be compared.
//
// Compiled only when LOCAL is defined (e.g. -DLOCAL). The judge links this
// file with its own grader.cpp, which already defines main; an unconditional
// definition here fails to link with "multiple definition of `main'".
#ifdef LOCAL
int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int T = 1;
    cin >> T;
    while (T --) {
        int n, x, y;
        long long k;
        cin >> n >> x >> y >> k;
        vector<int> U(n - 1), V(n - 1), W(n - 1);
        for (int i = 0; i < n - 1; i ++) cin >> U[i] >> V[i] >> W[i];
        // '\n' instead of endl: no need to flush after every answer.
        cout << max_score(n, x, y, k, U, V, W) << '\n';
        cout << max_score1(n, x, y, k, U, V, W) << '\n';
    }
}
#endif  // LOCAL

/*
1
12 3 8 500
0 1 2
1 2 34
2 3 23
3 4 22
4 5 12
5 6 2
6 7 6
7 8 9
8 9 21
9 10 123
10 11 23

*/

Compilation message

/usr/bin/ld: /tmp/cczHdHHg.o: in function `main':
grader.cpp:(.text.startup+0x0): multiple definition of `main'; /tmp/ccGxljaj.o:closing.cpp:(.text.startup+0x0): first defined here
collect2: error: ld returned 1 exit status