답안 #824218

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
824218 2023-08-13T19:04:45 Z vjudge1 Mergers (JOI19_mergers) C++17
0 / 100
87 ms 54768 KB
#include <bits/stdc++.h>
using namespace std;
// 123

static const int N = 1e6 + 10;

// Lowest-common-ancestor structure: an Euler tour of the tree combined with a
// sparse table that stores, for every power-of-two window of the tour, the
// vertex of minimum depth inside that window.
class lca_t {
       public:
        int rmq[20][N];      // rmq[k][i]: shallowest vertex among tour slots [i, i + 2^k)
        int pos[N], d[N];    // pos[v]: last tour slot holding v; d[v]: depth of v below the root
        vector<int> adj[N];  // undirected adjacency lists
        int timer;           // number of tour slots written so far
        int n;

        lca_t(int n = 0) : n(n) {}

        // Registers the undirected edge u - v.
        void add_edge(int u, int v) {
                adj[u].push_back(v);
                adj[v].push_back(u);
        }

        // Writes the Euler tour of the subtree of u (entered from p) into
        // rmq[0] and fills the depth of every descendant.
        void dfs(int u, int p) {
                rmq[0][timer++] = u;
                for (int child : adj[u]) {
                        if (child != p) {
                                d[child] = d[u] + 1;
                                dfs(child, u);
                                // Re-emit u after each child so that any two
                                // vertices are separated by their LCA.
                                rmq[0][timer++] = u;
                        }
                }
        }

        // Runs the tour from `root` and fills every sparse-table level.
        void build(int root) {
                timer = 0;
                d[root] = 0;
                dfs(root, -1);
                // Forward scan: later slots overwrite earlier ones.
                for (int slot = 0; slot < timer; ++slot) {
                        pos[rmq[0][slot]] = slot;
                }
                const int top_level = __lg(timer);
                assert(top_level < 20);
                for (int lvl = 1; lvl <= top_level; ++lvl) {
                        const int half = 1 << (lvl - 1);
                        const int last_start = timer - (1 << lvl);
                        for (int start = 0; start <= last_start; ++start) {
                                const int left = rmq[lvl - 1][start];
                                const int right = rmq[lvl - 1][start + half];
                                if (d[left] < d[right]) {
                                        rmq[lvl][start] = left;
                                } else {
                                        rmq[lvl][start] = right;
                                }
                        }
                }
        }

        // Returns the lowest common ancestor of u and v (build() must have run).
        int get(int u, int v) {
                const int a = pos[u];
                const int b = pos[v];
                const int lo = a < b ? a : b;
                const int hi = a < b ? b : a;
                const int lvl = __lg(hi - lo + 1);
                // Two overlapping windows of length 2^lvl cover [lo, hi].
                const int left = rmq[lvl][lo];
                const int right = rmq[lvl][hi - (1 << lvl) + 1];
                return d[left] < d[right] ? left : right;
        }

} lca;

// Disjoint-set union with path compression and union by size.
// prt[x] < 0  : x is a root and -prt[x] is the size of its set.
// prt[x] >= 0 : prt[x] is the parent of x.
class DSU {
       public:
        // Creates n + 1 singleton sets, for the elements 0..n.
        DSU(int n) {
                prt.resize(n + 1, -1);
        }

        // Number of elements in the set containing u.
        int SZ(int u) {
                return -prt[root(u)];
        }

        // Representative of the set containing u; compresses the path walked.
        int root(int u) {
                return (prt[u] < 0 ? u : (prt[u] = root(prt[u])));
        }

        // True when u and v are already in the same set.
        bool connected(int u, int v) {
                return root(u) == root(v);
        }

        // Merges the sets of u and v. Returns false if they were already
        // merged, true otherwise.
        bool unite(int u, int v) {
                u = root(u);
                v = root(v);
                if (u == v) {
                        return false;
                }

                // Keep the larger set (more negative prt) as the surviving root.
                // std::swap replaces the chained XOR swap, which is undefined
                // behaviour before C++17 and hard to read in any standard.
                if (prt[u] > prt[v]) {
                        std::swap(u, v);
                }

                prt[u] += prt[v];
                prt[v] = u;
                return true;
        }

       private:
        vector<int> prt;
};

// Reads a tree on n vertices (n - 1 one-based edges) followed by one label in
// 1..k per vertex, then:
//   1. for every label, finds the common ancestor of all vertices carrying it
//      and marks every tree edge lying between such a vertex and that ancestor;
//   2. contracts all marked edges with a DSU;
//   3. counts the contracted components touched by exactly one remaining edge
//      (the leaves of the contracted tree) and prints ceil(leaves / 2).
int32_t main() {
        ios_base::sync_with_stdio(0);
        cin.tie(0);
        int n, k;
        cin >> n >> k;
        vector<vector<int>> adj(n);
        for (int i = 0; i < n - 1; i++) {
                int u, v;
                cin >> u >> v;
                u--, v--;  // switch to 0-based vertex ids
                adj[u].emplace_back(v);
                adj[v].emplace_back(u);
                lca.add_edge(u, v);
        }
        vector<vector<int>> group(k);
        vector<int> a(n);
        for (int i = 0; i < n; i++) cin >> a[i], a[i]--;
        for (int i = 0; i < n; i++) group[a[i]].emplace_back(i);
        lca.build(0);

        // pf[v] gets +1 for every vertex and -|group| at the group's common
        // ancestor. After the subtree sums below, pf[v] is the number of
        // vertices inside v's subtree whose group ancestor lies outside it,
        // i.e. it is non-zero exactly when the edge (v, parent) must be kept.
        vector<int> pf(n, 0);
        for (int i = 0; i < k; i++) {
                if (group[i].empty()) continue;  // a label with no vertex has nothing to mark
                int top = group[i][0];
                for (int j : group[i]) top = lca.get(top, j);
                pf[top] -= static_cast<int>(group[i].size());
                for (int j : group[i]) pf[j]++;
        }
        DSU dsu(n);
        function<void(int, int)> dfs = [&](int u, int p) {
                for (int v : adj[u]) {
                        if (v == p) continue;
                        dfs(v, u);
                        pf[u] += pf[v];
                        if (pf[v]) {
                                dsu.unite(u, v);
                        }
                }
        };
        dfs(0, -1);

        // Degree of every contracted component, indexed by its DSU root.
        vector<int> deg(n);
        for (int i = 0; i < n; i++) {
                for (int j : adj[i]) {
                        if (i > j) continue;  // visit each undirected edge once
                        int x = dsu.root(i), y = dsu.root(j);
                        // Edges inside one component are not edges of the
                        // contracted tree; counting them inflated the degree
                        // and hid genuine leaves.
                        if (x == y) continue;
                        deg[x]++, deg[y]++;
                }
        }
        int res = 0;
        for (int i = 0; i < n; i++) res += deg[i] == 1;
        cout << ((res + 1) >> 1);
}

Compilation message

mergers.cpp: In function 'int32_t main()':
mergers.cpp:140:22: warning: suggest parentheses around '+' inside '>>' [-Wparentheses]
  140 |         cout << (res + 1 >> 1);
      |                  ~~~~^~~
# 결과 실행 시간 메모리 Grader output
1 Correct 11 ms 23764 KB Output is correct
2 Incorrect 10 ms 23892 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 11 ms 23764 KB Output is correct
2 Incorrect 10 ms 23892 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 11 ms 23764 KB Output is correct
2 Incorrect 10 ms 23892 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 59 ms 49804 KB Output is correct
2 Correct 87 ms 54768 KB Output is correct
3 Incorrect 13 ms 24660 KB Output isn't correct
4 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 11 ms 23764 KB Output is correct
2 Incorrect 10 ms 23892 KB Output isn't correct
3 Halted 0 ms 0 KB -