답안 #998740

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
998740 2024-06-14T15:40:21 Z mdn2002 봉쇄 시간 (IOI23_closing) C++17
컴파일 오류
0 ms 0 KB
/*
Mayoeba Yabureru
*/
//#include "closing.h"
#include<bits/stdc++.h>
using namespace std;

namespace {

// Adjacency list over 1-based node ids: gr[v] holds (neighbour, edge weight).
using Tree = std::vector<std::vector<std::pair<int, long long>>>;

// Score reported by evaluate_state() when forced costs overspend the budget.
constexpr int kInfeasible = -1000000000;

// The tree rooted at one source node.
struct Rooted {
    std::vector<int> order;       // preorder: every parent precedes its children
    std::vector<int> parent;      // parent[v]; 0 for the root (id 0 is not a node)
    std::vector<long long> dist;  // weighted distance from the root
};

// Roots `gr` at `root` using an explicit stack, so a long path-shaped tree
// cannot overflow the call stack the way a recursive DFS could.
Rooted root_tree(const Tree& gr, int root) {
    Rooted rt;
    rt.order.reserve(gr.size());
    rt.parent.assign(gr.size(), 0);
    rt.dist.assign(gr.size(), 0);
    std::vector<int> pending{root};
    while (!pending.empty()) {
        const int v = pending.back();
        pending.pop_back();
        rt.order.push_back(v);
        for (const auto& [u, w] : gr[v]) {
            if (u == rt.parent[v]) continue;
            rt.parent[u] = v;
            rt.dist[u] = rt.dist[v] + w;
            pending.push_back(u);
        }
    }
    return rt;
}

// Every node with did == 0 that lies strictly above some did == 2 node, in the
// tree rooted as `rt`, is marked did = 1 and its rt.dist is charged against
// `budget`. This is the same set of nodes a walk up the root path of each
// did == 2 node would touch, found in one bottom-up pass over the reversed
// preorder. Returns the number of newly marked nodes.
int charge_paths_to_doubled(const Rooted& rt, std::vector<int>& did, long long& budget) {
    std::vector<char> has_doubled_below(did.size(), 0);
    int marked = 0;
    for (auto it = rt.order.rbegin(); it != rt.order.rend(); ++it) {
        const int v = *it;
        if (has_doubled_below[v] && did[v] == 0) {
            did[v] = 1;
            budget -= rt.dist[v];
            ++marked;
        }
        const int p = rt.parent[v];
        if (p != 0 && (has_doubled_below[v] || did[v] == 2)) has_doubled_below[p] = 1;
    }
    return marked;
}

// Scores one state. `did` is taken by value, so the caller's state is left
// untouched; did == 2 marks nodes that are already doubled and paid for.
// First pays for the paths leading to doubled nodes from x and then from y.
// Then it greedily takes the cheapest remaining entries of `single` (sorted
// (cost, node) pairs, two per node), skipping nodes that are already marked.
// Returns the number of newly marked nodes, or kInfeasible if `budget` ended
// up negative.
int evaluate_state(const Rooted& from_x, const Rooted& from_y,
                   const std::vector<std::pair<long long, int>>& single,
                   std::vector<int> did, long long budget) {
    int reached = charge_paths_to_doubled(from_x, did, budget);
    reached += charge_paths_to_doubled(from_y, did, budget);
    for (const auto& [cost, v] : single) {
        if (did[v] != 0) continue;
        if (cost > budget) break;
        budget -= cost;
        did[v] = 1;
        ++reached;
    }
    return budget >= 0 ? reached : kInfeasible;
}

}  // namespace

// Entry point for the "closing" task (IOI23_closing per the submission header).
// Nodes are 0..n-1, edge i joins U[i]-V[i] with weight W[i], x and y are the
// two sources, and k is the budget.
//
// Strategy: nodes are "doubled" (cost max(dist_x, dist_y), worth 2) in
// ascending order of that cost while the budget allows. The state after every
// prefix, including the empty prefix and the full one, is scored with
// evaluate_state(), and the best total is returned.
// NOTE(review): greedy heuristic; optimality is not argued here.
// Worst case O(n^2): up to n + 1 evaluations of O(n) each.
int max_score(int n, int x, int y, long long k, std::vector<int> U, std::vector<int> V, std::vector<int> W) {
    // Shift to 1-based ids so that 0 can mean "no parent".
    ++x;
    ++y;
    Tree gr(n + 1);
    for (int i = 0; i + 1 < n; ++i) {
        const int u = U[i] + 1;
        const int v = V[i] + 1;
        gr[u].emplace_back(v, W[i]);
        gr[v].emplace_back(u, W[i]);
    }
    const Rooted from_x = root_tree(gr, x);
    const Rooted from_y = root_tree(gr, y);

    // Both candidate lists are independent of the evolving state, so they are
    // sorted once here rather than once per evaluation.
    std::vector<std::pair<long long, int>> single;
    std::vector<std::pair<long long, int>> doubling;
    for (int v = 1; v <= n; ++v) {
        single.emplace_back(from_x.dist[v], v);
        single.emplace_back(from_y.dist[v], v);
        doubling.emplace_back(std::max(from_x.dist[v], from_y.dist[v]), v);
    }
    std::sort(single.begin(), single.end());
    std::sort(doubling.begin(), doubling.end());

    std::vector<int> did(n + 1, 0);
    long long budget = k;
    int doubled_score = 0;
    int best = evaluate_state(from_x, from_y, single, did, budget);
    for (const auto& [cost, v] : doubling) {
        if (cost > budget) break;
        budget -= cost;
        did[v] = 2;
        doubled_score += 2;
        // Score after every doubling, including the one that doubles the last node.
        best = std::max(best, doubled_score + evaluate_state(from_x, from_y, single, did, budget));
    }
    return best;
}
/*
2
7 0 2 10
0 1 2
0 3 3
1 2 4
2 4 2
2 5 5
5 6 3

4 0 3 20
0 1 18
1 2 1
2 3 19
*/
// Local driver. Reads T test cases. Each case is "n x y k" followed by n-1
// lines "u v w", and the program prints max_score for each case on its own line.
// NOTE(review): the grader header include above is commented out. If this file
// is submitted alongside a grader that defines its own main(), this function
// presumably has to be removed or guarded; confirm against the judge setup.
int main()
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int T = 1;
    std::cin >> T;
    while (T--) {
        int n, x, y;
        long long k;
        std::cin >> n >> x >> y >> k;
        std::vector<int> U(n - 1), V(n - 1), W(n - 1);
        for (int i = 0; i < n - 1; ++i) std::cin >> U[i] >> V[i] >> W[i];
        // '\n' instead of std::endl: no need to flush after every answer.
        std::cout << max_score(n, x, y, k, U, V, W) << '\n';
    }
}

Compilation message

closing.cpp: In lambda function:
closing.cpp:71:16: error: inconsistent types 'int' and 'double' deduced for lambda return type
   71 |         return -1e9;
      |                ^~~~