제출 #1368759

#제출 시각아이디문제언어결과실행 시간메모리
1368759SulAMergers (JOI19_mergers)C++20
0 / 100
99 ms56020 KiB
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#define bitcount __builtin_popcountll
#define all(a) (a).begin(), (a).end()
using namespace std;
using namespace __gnu_pbds;
template<typename T> using ordered_set = tree<T, null_type, less_equal<T>, rb_tree_tag, tree_order_statistics_node_update>;

const int N = 5e5 + 1;
vector<pair<int,int>> adj[N];
unordered_map<int,int> stl[N];
bool sep[N];
int tot[N], region[N], deg[N];

int par[N];

int find(int u) { return par[u] == u ? u : par[u] = find(par[u]); }
void merge(int u, int v) {
    assert(find(u) != find(v));
    par[find(v)] = find(u);
}


void dfs1(int u, int p = -1) {
    if (tot[stl[u].begin()->first] == 1)
        stl[u].clear();
    for (auto [v, ind] : adj[u]) {
        if (v == p) continue;
        dfs1(v, u);
        if (stl[v].empty()) {
            sep[ind] = true;
//            cout << u << " " << v << "\n";
        }

        if (stl[u].size() < stl[v].size())
            swap(stl[u], stl[v]);
        for (auto [col, frq] : stl[v]) {
            auto [it, _] = stl[u].emplace(col, 0);
            it->second += frq;
            if (it->second == tot[col]) {
                stl[u].erase(it);
            }
        }
    }
}

int timer = 0;
void dfs2(int u, int p = -1) {
    region[u] = timer;
    for (int b = 0; b <= 1; b++) {
        for (auto [v, ind]: adj[u]) {
            if (v == p) continue;
            timer += sep[ind];
            if (sep[ind] == b) dfs2(v, u);
        }
    }
}

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr), cout.tie(nullptr);

    int n, k; cin >> n >> k;
    for (int i = 0, u, v; i < n-1; i++) {
        cin >> u >> v;
        adj[u].emplace_back(v, i);
        adj[v].emplace_back(u, i);
    }
    for (int i = 1, col; i <= n; i++) {
        cin >> col;
        tot[col]++;
        stl[i][col]++;
    }
    dfs1(1);
    dfs2(1);
    set<pair<int,int>> s;
    for (int u = 1; u <= n; u++) {
        for (auto [v, ind] : adj[u]) {
            if (sep[ind]) {
                int x = region[u];
                int y = region[v];
                s.emplace(min(x, y), max(x, y));
            }
        }
    }
    for (int i = 0; i < N; i++)
        par[i] = i;
    for (auto [x, y] : s) {
        deg[x]++, deg[y]++;
        merge(x, y);
    }
    int ans = 0;
    for (int i = 1; i <= n; i++)
        ans += deg[i] == 1;
    cout << (ans + 1)/2;
}
#결과 실행 시간메모리채점기 출력
결과를 불러오는 중입니다…
#결과 실행 시간메모리채점기 출력
결과를 불러오는 중입니다…
#결과 실행 시간메모리채점기 출력
결과를 불러오는 중입니다…
#결과 실행 시간메모리채점기 출력
결과를 불러오는 중입니다…
#결과 실행 시간메모리채점기 출력
결과를 불러오는 중입니다…