답안 #796650

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
796650 2023-07-28T15:20:34 Z rnl42 Mergers (JOI19_mergers) C++14
컴파일 오류
0 ms 0 KB
#include <algorithm>
#include <cassert>
#include <iostream>
#include <numeric>
#include <vector>
using namespace std;

// Capacity of every per-node array below (node indices are 0-based).
const int MAXN = 5e5;

// N and K are read from stdin in main(); N bounds the node loops, K bounds the group loop.
int N, K;
// dist[v]: depth of v in edges below node 0, filled in by dfs().
int dist[MAXN];
// adj[v]: neighbours of v as read from input; dfs() erases the parent entry,
// so afterwards only the children of v remain.
vector<int> adj[MAXN];
// group_of[v]: 0-based group index of node v (the input value minus one).
int group_of[MAXN];
// members[g]: nodes whose group_of is g, stored in increasing node order.
vector<int> members[MAXN];
// One cell of the range-minimum segment tree used for LCA queries.
// Holds a node index, or -1 as the "no node" sentinel that compares greater
// than every real node. Real nodes are ordered by their depth in dist[].
struct lcatreeitem {
    int i;
    // Store a node index (or the -1 sentinel) in this cell.
    lcatreeitem& operator=(int other) {
        i = other;
        return *this;
    }
    // Read the stored node index back.
    operator int() const {
        return i;
    }
    // Strict ordering by depth; the sentinel is never smaller than anything.
    bool operator<(const lcatreeitem& other) const {
        if (i == -1) return false;
        if (other.i == -1) return true;
        // The check must come after both sentinel tests and must not sit
        // between an `if` and its `else`, which is what broke compilation.
        assert(i >= 0 && other.i >= 0 && i < N && other.i < N);
        return dist[i] < dist[other.i];
    }
// dfs() writes one leaf per Euler-tour entry, i.e. 2*N-1 leaves. For
// N == MAXN that is 999999, so 1<<20 leaves (1<<21 cells in total) are
// required; the previous 1<<19 leaves overflowed once N exceeded 262144.
} lcatree[1<<21];
// first[v]: Euler-tour position at which dfs() first visited v.
int first[MAXN];
// Next free Euler-tour position; advanced by dfs().
int lcatree_i = 0;
// Index of the first leaf inside lcatree[] (leaves occupy [shift, 2*shift)).
const int lcatree_shift = 1<<20;
// mergeuntil[v]: initialised to v in main(), then overwritten with the LCA of
// all members of v's group; dfs2() reads it.
int mergeuntil[MAXN];
// penible[v]: set in main() for every v >= 1 that is its own union-find
// representative after dfs2(); dfs3() reads it.
bool penible[MAXN];

// uf[v]: union-find parent pointer; uf[v] == v marks a representative.
int uf[MAXN];

// Returns the union-find representative of u and repoints every node on the
// walked chain directly at that representative.
int root(int u) {
    // First pass: climb to the representative.
    int top = u;
    while (uf[top] != top) {
        top = uf[top];
    }
    // Second pass: compress the chain that was just walked.
    while (uf[u] != top) {
        const int next = uf[u];
        uf[u] = top;
        u = next;
    }
    return top;
}
// Unites the sets of u and v. The representative of whichever argument has
// the smaller dist[] (u on ties) becomes the representative of the union.
void merge(int u, int v) {
    const bool u_is_upper = dist[u] <= dist[v];
    const int upper = u_is_upper ? u : v;
    const int lower = u_is_upper ? v : u;
    const int upper_root = root(upper);
    const int lower_root = root(lower);
    uf[lower_root] = upper_root;
}

void dfs(int u) {
    first[u] = lcatree_i;
    lcatree[lcatree_shift+lcatree_i++] = u;
    for (int v : adj[u]) {
        adj[v].erase(find(adj[v].begin(), adj[v].end(), u));
        dist[v] = dist[u]+1;
        dfs(v);
        lcatree[lcatree_shift+lcatree_i++] = u;
    }
}

// Returns the node of smallest dist[] among the mergeuntil[] values found in
// the subtree of u. A child is merged into u's set whenever the value
// returned for that child is not the child itself.
int dfs2(int u) {
    int highest = mergeuntil[u];
    for (int child : adj[u]) {
        const int reach = dfs2(child);
        if (reach != child) {
            merge(child, u);
        }
        if (dist[reach] < dist[highest]) {
            highest = reach;
        }
    }
    return highest;
}

int lca(int u, int v) {
    if (first[u] > first[v]) swap(u, v);
    int l = lcatree_shift+first[u], r = lcatree_shift+first[v]+1;
    lcatreeitem ret;
    ret = -1;
    for (; l < r; l >>= 1, r >>= 1) {
        if (l&1) {
            ret = min(ret, lcatree[l++]);
        }
        if (r&1) {
            ret = min(ret, lcatree[--r]);
        }
    }
    //cerr << "lca(" << u << ", " << v << ") = " << (int)ret << '\n';
    return ret;
}

// Counts the penible[] nodes in the subtree of u that have no penible[] node
// strictly below them.
int dfs3(int u) {
    int below = 0;
    for (int child : adj[u]) {
        below += dfs3(child);
    }
    // A flagged node only counts when nothing beneath it was counted already.
    if (below == 0 && penible[u]) {
        return 1;
    }
    return below;
}

signed main() {
    ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
    cin >> N >> K;
    //dist.resize(N);
    //adj.resize(N);
    //first.resize(N);
    //group_of.resize(N);
    //members.resize(K);
    //lcatree.resize(1<<20);
    //uf.resize(N);
    //mergeuntil.resize(N);
    //penible.resize(N);
    iota(mergeuntil, mergeuntil+N, 0);
    iota(uf, uf+N, 0);
    int u, v;
    for (int i = 0; i < N-1; i++) {
        cin >> u >> v, u--, v--;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for (int i = 0; i < N; i++) {
        cin >> group_of[i], group_of[i]--;
        members[group_of[i]].push_back(i);
    }
    dfs(0);
    for (int i = lcatree_shift-1; i > 0; i--) {
        lcatree[i] = min(lcatree[2*i], lcatree[2*i+1]);
    }
    assert(N < 5e5);
    for (int g = 0; g < K; g++) {
        int a = members[g][0];
        for (auto m : members[g]) {
            a = lca(a, m);
        }
        for (auto m : members[g]) {
            mergeuntil[m] = a;
        }
    }
    dfs2(0);
    bool plus1 = false;
    int lcapenible = -1;
    for (int i = 1; i < N; i++) {
        if (root(i) == i) {
            if (lcapenible == -1) lcapenible = i;
            else lcapenible = lca(lcapenible, i);
            penible[i] = true;
        }
    }
    for (int i = 1; i < N; i++) {
        if (penible[i]) {
            if (lca(lcapenible, i) == i) {
                plus1 = true;
                break;
            }
        }
    }
    int ans = dfs3(0)+plus1;
    cout << ((ans+1)>>1) << '\n';
}

Compilation message

mergers.cpp: In member function 'bool lcatreeitem::operator<(const lcatreeitem&) const':
mergers.cpp:26:9: error: 'assert' was not declared in this scope
   26 |         assert(i >= 0 && other.i >= 0 && i < N && other.i < N);
      |         ^~~~~~
mergers.cpp:5:1: note: 'assert' is defined in header '<cassert>'; did you forget to '#include <cassert>'?
    4 | #include <numeric>
  +++ |+#include <cassert>
    5 | using namespace std;
mergers.cpp:27:9: error: 'else' without a previous 'if'
   27 |         else return dist[i] < dist[other.i];
      |         ^~~~
mergers.cpp: In function 'int main()':
mergers.cpp:125:5: error: 'assert' was not declared in this scope
  125 |     assert(N < 5e5);
      |     ^~~~~~
mergers.cpp:125:5: note: 'assert' is defined in header '<cassert>'; did you forget to '#include <cassert>'?