제출 #796693

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
796693 |  | rnl42 | Mergers (JOI19_mergers) | C++14 | 70 / 100 | 303 ms | 161880 KiB
#include <iostream>
#include <vector>
#include <algorithm>
#include <numeric>
#include <cassert>
using namespace std;
#define int long long

// Upper bound on the number of cities (tree nodes); sizes every per-node array.
constexpr int MAXN = 500000;

// Number of cities and number of states, read from the input.
int N;
int K;
// dist[u]: depth of u below the root (node 0), filled in by the Euler-tour DFS.
int dist[MAXN];
// Adjacency lists; the first DFS erases each node's parent, leaving only children.
vector<int> adj[MAXN];
// group_of[u]: zero-based state of city u.
int group_of[MAXN];
// members[g]: all cities belonging to state g.
vector<int> members[MAXN];
// Leaves of the LCA segment tree live at [lcatree_shift, 2*lcatree_shift).
// The Euler tour written by dfs() has 2N-1 entries, so the shift must be at
// least 2*MAXN-1. The former value 1<<19 (with a 1<<20 array) overflowed
// lcatree as soon as N exceeded 2^18.
const int lcatree_shift = 1<<20;
static_assert(lcatree_shift >= 2*MAXN - 1, "LCA segment tree too small for the Euler tour");

// Segment-tree element: a node index ordered by depth (dist). The value -1 is
// the identity of min(): it never compares less than anything, and every real
// node compares less than it.
struct lcatreeitem {
    int i;
    lcatreeitem& operator=(int other) {
        i = other;
        return *this;
    }
    operator int() const {
        return i;
    }
    bool operator<(const lcatreeitem& other) const {
        if (i == -1) return false;
        else if (other.i == -1) return true;
        return dist[i] < dist[other.i];
    }
} lcatree[2*lcatree_shift];
int first[MAXN];       // first[u]: index of u's first occurrence in the Euler tour
int lcatree_i = 0;     // number of Euler-tour entries written so far
int mergeuntil[MAXN];  // mergeuntil[u]: LCA of u's group (assigned in main)
bool penible[MAXN];    // set in main for every union-find root other than node 0

int uf[MAXN];          // union-find parent pointers

// Union-find lookup with path compression.
int root(int u) {
    return uf[u] == u ? u : uf[u] = root(uf[u]);
}

// Unites the sets of u and v. The set of the deeper of the two nodes is hung
// under the set of the shallower one; on a depth tie, v's set goes under u's.
void merge(int u, int v) {
    if (dist[u] > dist[v]) {
        swap(u, v);
    }
    uf[root(v)] = root(u);
}

// Euler-tour DFS from u. It writes u into the segment-tree leaves on every
// visit and records first[u]. It also sets each child's depth and erases the
// parent from the child's adjacency list, so adj[] afterwards holds only children.
// NOTE(review): recursion depth equals the tree depth (up to N); this relies on
// the judge providing a large enough stack.
void dfs(int u) {
    first[u] = lcatree_i;
    lcatree[lcatree_shift+lcatree_i++] = u;
    for (int v : adj[u]) {
        adj[v].erase(find(adj[v].begin(), adj[v].end(), u));
        dist[v] = dist[u]+1;
        dfs(v);
        lcatree[lcatree_shift+lcatree_i++] = u;
    }
}

// Post-order pass that returns the shallowest mergeuntil[] target in u's subtree.
// If a child's subtree must reach above that child, some group spans the edge
// (child, u), so the child's component is merged into u's.
int dfs2(int u) {
    int ret = mergeuntil[u];
    for (int v : adj[u]) {
        int r = dfs2(v);
        if (dist[r] < dist[ret]) ret = r;
        if (r != v) {
            merge(v, u);
        }
    }
    return ret;
}

// Lowest common ancestor of u and v: the minimum-depth node in the Euler tour
// between their first occurrences, found with a bottom-up segment-tree query.
int lca(int u, int v) {
    assert(u >= 0 && v >= 0 && u < N && v < N);
    if (first[u] > first[v]) swap(u, v);
    int l = lcatree_shift+first[u], r = lcatree_shift+first[v]+1;
    lcatreeitem ret;
    ret = -1;  // identity element for min
    for (; l < r; l >>= 1, r >>= 1) {
        if (l&1) {
            // Bound checks use the real array size; the old literal 1e6 was
            // unrelated to it.
            assert(l >= 0 && l < 2*lcatree_shift);
            ret = min(ret, lcatree[l++]);
        }
        if (r&1) {
            --r;
            assert(r >= 0 && r < 2*lcatree_shift);
            ret = min(ret, lcatree[r]);
        }
    }
    return ret;
}

// Counts the marked (penible) nodes in u's subtree that have no marked proper
// descendant. main adds dfs3(0) to the leaf count of the contracted tree.
int dfs3(int u) {
    int leaves = 0;
    for (size_t idx = 0; idx < adj[u].size(); ++idx) {
        leaves += dfs3(adj[u][idx]);
    }
    // Nothing counted below u: u itself counts exactly when it is marked.
    if (leaves == 0 && penible[u]) {
        leaves = 1;
    }
    return leaves;
}

// Reads the tree and the state of every city, then contracts every edge that
// some state spans. It prints ceil(L/2), where L is the number of leaves of
// the contracted tree (0 when everything collapses into one component).
signed main() {
    ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
    cin >> N >> K;
    iota(mergeuntil, mergeuntil+N, 0);
    iota(uf, uf+N, 0);
    int u, v;
    for (int i = 0; i < N-1; i++) {
        cin >> u >> v, u--, v--;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    for (int i = 0; i < N; i++) {
        cin >> group_of[i], group_of[i]--;
        members[group_of[i]].push_back(i);
    }
    // Root the tree at node 0 (Euler tour + depths), then build the RMQ tree bottom-up.
    dfs(0);
    for (int i = lcatree_shift-1; i > 0; i--) {
        lcatree[i] = min(lcatree[2*i], lcatree[2*i+1]);
    }
    // Every member of group g must be merged up to the LCA of the whole group.
    for (int g = 0; g < K; g++) {
        // A group with no members constrains nothing. Skipping it also avoids
        // reading members[g][0] out of bounds.
        if (members[g].empty()) continue;
        int a = members[g][0];
        for (auto m : members[g]) {
            a = lca(a, m);
        }
        for (auto m : members[g]) {
            mergeuntil[m] = a;
        }
    }
    dfs2(0);
    // Mark every union-find root other than node 0 (the top of each non-root
    // component) and take the LCA of all of them.
    bool plus1 = false;
    int lcapenible = -1;
    for (int i = 1; i < N; i++) {
        if (root(i) == i) {
            if (lcapenible == -1) lcapenible = i;
            else lcapenible = lca(lcapenible, i);
            penible[i] = true;
        }
    }
    // If one marked node is an ancestor of all the others, the root component
    // touches only that component and therefore is a leaf as well.
    for (int i = 1; i < N; i++) {
        if (penible[i]) {
            if (lca(lcapenible, i) == i) {
                plus1 = true;
                break;
            }
        }
    }
    int ans = dfs3(0)+plus1;
    cout << ((ans+1)>>1) << '\n';
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...