Submission #726261

| # | Submission time | Handle | Problem | Language | Result | Execution time | Memory |
|---|---|---|---|---|---|---|---|
| 726261 | 2023-04-18T17:21:02Z | vjudge1 | Toll (APIO13_toll) | C++17 | 16 / 100 | 4 ms | 5204 KB |
#include <bits/stdc++.h>
using namespace std;

// Vertex count, number of weighted input edges, number of extra (tollable) edges.
int n;
int m;
int k;

// Input edge j joins a[j] and b[j] with weight c[j]. ord[] holds edge indices
// sorted by weight; cnt[j] counts how many of main's two Kruskal passes kept edge j.
int a[300000];
int b[300000];
int c[300000];
int ord[300000];
int cnt[300000];

// Extra edge j joins x[j] and y[j]; cost[j] is the weight assigned to it while
// evaluating the current subset of extra edges.
int x[20];
int y[20];
int cost[20];

// pp[v]: per-vertex value read from input; p[v]: value aggregated at DSU roots.
int64_t p[100000];
int64_t pp[100000];

// DSU parent array: a negative entry marks a root and stores -(component size),
// otherwise it is the parent's index.
int64_t par[100000];

// (neighbor, weight) lists of the tree built over DSU roots for one subset.
vector<pair<int, int>> adj[100000];

// Undo log for DSU writes made with save=true: (address written, previous value).
vector<pair<int64_t*, int64_t>> rollback;

// Returns the DSU representative of u.
// save == true : pure lookup, the structure is not modified (keeps rollback valid).
// save == false: path compression, every visited node is re-pointed at the root.
inline int root(int u, bool save) {
        if (par[u] < 0) return u;
        if (save) return root(par[u], save);
        par[u] = root(par[u], save);
        return par[u];
}

// True when u and v share a DSU representative; uses non-mutating lookups.
inline bool is_connected(int u, int v) {
        const int ru = root(u, 1);
        const int rv = root(v, 1);
        return ru == rv;
}

// Merges the components of u and v using union by size. Returns false if they
// were already joined. The surviving root's p[] absorbs the other root's p[].
// With save == true no path compression happens and every write is first
// logged in `rollback` so it can be undone later.
inline bool unite(int u, int v, bool save = 0) {
        int ru = root(u, save);
        int rv = root(v, save);
        if (ru == rv) return false;

        // par[root] is -(size): keep the more negative (larger) one as ru.
        if (par[rv] < par[ru]) swap(ru, rv);

        if (save) {
                rollback.emplace_back(&par[ru], par[ru]);
                rollback.emplace_back(&par[rv], par[rv]);
                rollback.emplace_back(&p[ru], p[ru]);
        }

        p[ru] += p[rv];
        par[ru] += par[rv];
        par[rv] = ru;
        return true;
}

// Accumulator filled by dfs(); the caller resets it before each walk.
int64_t ans;

// Walks the tree stored in adj[], rooted at u with parent pr.
// Returns the sum of p[] over u's subtree and, for every edge to a child,
// adds (child subtree sum) * (edge weight) to `ans`.
int64_t dfs(int u, int pr) {
        int64_t subtree = p[u];
        for (const auto& [to, w] : adj[u]) {
                if (to == pr) continue;
                const int64_t below = dfs(to, u);
                ans += below * w;
                subtree += below;
        }
        return subtree;
}

// Undoes logged DSU writes, newest first, until only `keep` entries remain in `rollback`.
static void undo_until(size_t keep) {
        while (rollback.size() > keep) {
                auto& rec = rollback.back();
                *rec.first = rec.second;
                rollback.pop_back();
        }
}

// Reads the graph, permanently contracts edges kept by both Kruskal passes,
// then enumerates every non-empty subset of the k extra edges and prints the
// best total of (subtree p-sum) * (assigned edge cost) found.
int main() {
        ios::sync_with_stdio(false);
        cin.tie(nullptr);

        cin >> n >> m >> k;
        for (int i = 0; i < m; i++) cin >> a[i] >> b[i] >> c[i], a[i]--, b[i]--;
        for (int i = 0; i < k; i++) cin >> x[i] >> y[i], x[i]--, y[i]--;
        for (int i = 0; i < n; i++) cin >> pp[i];

        // Pass 1: Kruskal over the input edges only.
        fill(par, par + n, -1);
        iota(ord, ord + m, 0);
        sort(ord, ord + m, [](int lhs, int rhs) { return c[lhs] < c[rhs]; });
        for (int j = 0; j < m; j++) cnt[ord[j]] += unite(a[ord[j]], b[ord[j]]);

        // Pass 2: Kruskal with all k extra edges forced in first.
        fill(par, par + n, -1);
        for (int i = 0; i < k; i++) unite(x[i], y[i]);
        for (int j = 0; j < m; j++) cnt[ord[j]] += unite(a[ord[j]], b[ord[j]]);

        // cnt == 2: kept by both passes -> contracted once, for good.
        // cnt == 1: kept by exactly one pass -> candidate edge, collected in weight order.
        vector<int> almost_mst;
        fill(par, par + n, -1);
        for (int i = 0; i < n; i++) p[i] = pp[i];
        for (int i = 0; i < m; i++) {
                if (cnt[ord[i]] == 1) almost_mst.emplace_back(ord[i]);
                if (cnt[ord[i]] == 2) unite(a[ord[i]], b[ord[i]]);
        }
        for (int i = 0; i < n; i++) root(i, 0);  // compress fully so save-mode lookups stay shallow

        assert(almost_mst.size() <= static_cast<size_t>(k));

        int64_t res = 0;
        // Per-subset scratch buffers, reused to avoid reallocating 2^k times.
        vector<int> not_mst, mst, touched;

        for (int i = 1; i < 1 << k; i++) {
                // Force the selected extra edges; a subset containing a cycle is skipped.
                bool ok = true;
                for (int j = 0; j < k && ok; j++) {
                        if ((i >> j & 1) && !unite(x[j], y[j], 1)) ok = false;
                }
                if (!ok) {
                        undo_until(0);
                        continue;
                }

                // Greedily add candidate edges; rejected ones are the edges that cap the costs.
                not_mst.clear();
                mst.clear();
                for (int j : almost_mst) {
                        if (unite(a[j], b[j], 1)) mst.emplace_back(j);
                        else not_mst.emplace_back(j);
                }
                undo_until(0);

                // Keep only the accepted candidates merged; extra edges become edges between DSU roots.
                for (int j : mst) unite(a[j], b[j], 1);
                const size_t base = rollback.size();

                // cost[j] = lightest rejected edge that reconnects the tree once extra edge j is removed.
                for (int j = 0; j < k; j++) {
                        if (!(i >> j & 1)) continue;  // cost of an unselected edge is never read
                        for (int z = 0; z < k; z++) {
                                if ((i >> z & 1) && z != j) unite(x[z], y[z], 1);
                        }
                        cost[j] = 1000000000;
                        for (int z : not_mst) {
                                if (!is_connected(a[z], b[z])) cost[j] = min(cost[j], c[z]);
                        }
                        undo_until(base);
                }

                // Build the tree over the current DSU roots, remembering which lists were written.
                touched.clear();
                for (int j = 0; j < k; j++) {
                        if (!(i >> j & 1)) continue;
                        const int rx = root(x[j], 1), ry = root(y[j], 1);
                        adj[rx].emplace_back(ry, cost[j]);
                        adj[ry].emplace_back(rx, cost[j]);
                        touched.push_back(rx);
                        touched.push_back(ry);
                }

                ans = 0;
                dfs(root(0, 1), -1);
                res = max(res, ans);

                // Clear adjacency BEFORE undoing the DSU: after rollback, root() yields
                // different representatives than those used above. Recomputing them
                // popped from the wrong (possibly empty) vectors -> UB / SIGSEGV.
                for (int v : touched) adj[v].clear();
                undo_until(0);
        }

        cout << res;
}

Compilation message

In file included from /usr/include/c++/10/cassert:44,
                 from /usr/include/x86_64-linux-gnu/c++/10/bits/stdc++.h:33,
                 from toll.cpp:1:
toll.cpp: In function 'int main()':
toll.cpp:74:34: warning: comparison of integer expressions of different signedness: 'std::vector<int>::size_type' {aka 'long unsigned int'} and 'int' [-Wsign-compare]
   74 |         assert(almost_mst.size() <= k);
      |                ~~~~~~~~~~~~~~~~~~^~~~
toll.cpp:124:48: warning: comparison of integer expressions of different signedness: 'std::vector<std::pair<long int*, long int> >::size_type' {aka 'long unsigned int'} and 'int' [-Wsign-compare]
  124 |                         while (rollback.size() > roll_back_size) {
      |                                ~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~
# Verdict Execution time Memory Grader output
1 Correct 2 ms 2644 KB Output is correct
2 Correct 2 ms 2644 KB Output is correct
# Verdict Execution time Memory Grader output
1 Correct 2 ms 2644 KB Output is correct
2 Correct 2 ms 2644 KB Output is correct
3 Runtime error 4 ms 5204 KB Execution killed with signal 11
4 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Correct 2 ms 2644 KB Output is correct
2 Correct 2 ms 2644 KB Output is correct
3 Runtime error 4 ms 5204 KB Execution killed with signal 11
4 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Correct 2 ms 2644 KB Output is correct
2 Correct 2 ms 2644 KB Output is correct
3 Runtime error 4 ms 5204 KB Execution killed with signal 11
4 Halted 0 ms 0 KB -
# Verdict Execution time Memory Grader output
1 Correct 2 ms 2644 KB Output is correct
2 Correct 2 ms 2644 KB Output is correct
3 Runtime error 4 ms 5204 KB Execution killed with signal 11
4 Halted 0 ms 0 KB -