답안 #604387

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
604387 2022-07-25T05:29:13 Z talant117408 Mergers (JOI19_mergers) C++17
0 / 100
347 ms 55952 KB
#include <bits/stdc++.h>
 
using namespace std;
 
typedef long long ll;
typedef pair <int, int> pii;
typedef pair <ll, ll> pll;

// Contest shorthands. No macro may redefine a keyword: the standard forbids
// #define of keyword names in a translation unit that includes standard headers.
#define pb                  push_back
#define mp                  make_pair
#define all(v)              (v).begin(),(v).end()
#define rall(v)             (v).rbegin(),(v).rend()
#define lb                  lower_bound
#define ub                  upper_bound
#define sz(v)               int((v).size())
#define do_not_disturb      ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
#define endl                '\n'
// Parenthesised so expressions such as `x / PI` divide by pi as a whole.
#define PI                  (2*acos(0.0))

// Capacity for vertex-indexed arrays (vertices are 1-indexed).
const int N = 500007;

// Adjacency lists of the input tree and of the contracted component graph.
vector <int> graph[N];
vector <int> comps_graph[N];

// color[v]: colour of vertex v; states[c]: number of vertices of colour c;
// comp[v]: component id assigned by find_comps; num: components found so far.
int color[N];
int states[N];
int comp[N];
int num;

// Euler-tour entry/exit timestamps and the shared counter that produces them.
int tin[N];
int tout[N];
int timer;

// color_lca[c]: lowest common ancestor of the vertices of colour c.
int color_lca[N];

// Binary-lifting table: up[v][i] is the 2^i-th ancestor of v.
int up[N][20];

// Number of vertices of the current test.
int n;

// Tree edges keyed as (smaller endpoint, larger endpoint), marked true by
// find_states when their coverage counter is zero.
map <pii, bool> separators;

// LCA
// True iff `a` is an ancestor of `b` (a vertex counts as its own ancestor):
// a's Euler-tour interval [tin[a], tout[a]], filled in by dfs(), encloses b's.
bool upper(int a, int b) {
    const bool enters_first = tin[b] >= tin[a];
    const bool leaves_last = tout[b] <= tout[a];
    return enters_first && leaves_last;
}

int get_lca(int a, int b) {
    if (upper(a, b)) return a;
    if (upper(b, a)) return b;
    for (int i = 19; i >= 0; i--) {
        if (!upper(up[a][i], b)) {
            a = up[a][i];
        }
    }
    return up[a][0];
}

// HLD
int heavy[N];  // child with the largest subtree (set by dfs), -1 for a leaf
int head[N];   // top vertex of the heavy chain containing each vertex
int pos[N];    // position of each vertex in decomposition order (set by decompose)

// Depth-first pass from v, whose parent is p.
// - Records the Euler-tour times tin[v] / tout[v].
// - Fills the binary-lifting row up[v][0..19].
// - Sets heavy[v] to the child with the largest subtree: the first such child
//   on ties, -1 when v has no children.
// Returns the size of v's subtree. solve() calls dfs(1, 1), so every lifted
// ancestor of the root is the root itself.
int dfs(int v, int p) {
    tin[v] = ++timer;
    heavy[v] = -1;
    up[v][0] = p;
    for (int lvl = 1; lvl < 20; ++lvl) {
        const int halfway = up[v][lvl - 1];
        up[v][lvl] = up[halfway][lvl - 1];
    }

    int subtree = 1;
    int biggest_child = 0;
    for (const int child : graph[v]) {
        if (child != p) {
            const int child_size = dfs(child, v);
            subtree += child_size;
            if (child_size > biggest_child) {
                biggest_child = child_size;
                heavy[v] = child;
            }
        }
    }
    tout[v] = ++timer;
    return subtree;
}

// Heavy-light decomposition. Numbers vertices with consecutive positions pos[],
// visiting the heavy child first so each heavy chain occupies a contiguous range.
// Stores in head[] the top vertex of each vertex's chain; a light child starts a
// new chain headed by itself. Requires heavy[] and up[][0] from dfs().
void decompose(int v, int h) {
    head[v] = h;
    pos[v] = ++timer;
    const int hc = heavy[v];
    if (hc != -1) {
        decompose(hc, h);
    }
    const int parent = up[v][0];
    for (const int to : graph[v]) {
        const bool is_light_child = to != parent && to != hc;
        if (is_light_child) {
            decompose(to, to);
        }
    }
}

int tree[N*4], lz[N*4];

// Applies node v's pending addition to its sum over segment [l, r]. Unless v
// is a leaf, the addition is also passed down to both children. Clears v's tag.
void push(int v, int l, int r) {
    if (lz[v] == 0) {
        return;
    }
    const int len = r - l + 1;
    tree[v] += len * lz[v];
    const bool is_leaf = (l == r);
    if (!is_leaf) {
        lz[2 * v] += lz[v];
        lz[2 * v + 1] += lz[v];
    }
    lz[v] = 0;
}

// Range add: adds val to every position of [ql, qr] inside node v's segment
// [l, r], using lazy tags, and keeps tree[v] equal to its segment's sum.
void update(int v, int l, int r, int ql, int qr, int val) {
    push(v, l, r);
    if (qr < l || r < ql) {
        return;
    }
    if (l >= ql && r <= qr) {
        lz[v] += val;
        push(v, l, r);  // fold the new tag in now so tree[v] is up to date
        return;
    }
    const int m = l + (r - l) / 2;
    const int lc = 2 * v;
    const int rc = lc + 1;
    update(lc, l, m, ql, qr, val);
    update(rc, m + 1, r, ql, qr, val);
    tree[v] = tree[lc] + tree[rc];
}

// Sum of positions [ql, qr] within node v's segment [l, r], pushing pending
// tags down along the way.
int get(int v, int l, int r, int ql, int qr) {
    push(v, l, r);
    if (qr < l || r < ql) {
        return 0;
    }
    if (l >= ql && r <= qr) {
        return tree[v];
    }
    const int m = (l + r) / 2;
    const int left_sum = get(2 * v, l, m, ql, qr);
    const int right_sum = get(2 * v + 1, m + 1, r, ql, qr);
    return left_sum + right_sum;
}

// Adds +1 at the position of every vertex on the path from v up to `origin`,
// excluding `origin` itself. `origin` must be an ancestor of v; solve() passes
// the LCA of v's colour. The value at pos[u] therefore counts the paths that
// use the edge (u, parent of u).
void hld_update(int v, int origin) {
    const int target_chain = head[origin];
    while (head[v] != target_chain) {
        const int top = head[v];
        update(1, 1, n, pos[top], pos[v], 1);
        v = up[top][0];
    }
    // v and origin now share a chain, and origin sits at or above v.
    update(1, 1, n, pos[origin], pos[v], 1);
    update(1, 1, n, pos[origin], pos[origin], -1);
}

// Walks the subtree of v (parent p). Edge {v, p} is marked as a separator when
// the point value at pos[v] is 0, i.e. no hld_update path covered it.
// Keys are (smaller endpoint, larger endpoint).
void find_states(int v, int p) {
    const bool uncovered = get(1, 1, n, pos[v], pos[v]) == 0;
    if (uncovered) {
        const pair<int, int> key{min(v, p), max(v, p)};
        separators[key] = true;
    }
    for (const int child : graph[v]) {
        if (child != p) {
            find_states(child, v);
        }
    }
}

// Flood-fills the component of v with id `num`, crossing only tree edges that
// find_states did not mark as separators. Uses find() instead of operator[] so
// that looking up an unmarked edge does not insert a `false` entry into the map.
void find_comps(int v, int p) {
    comp[v] = num;
    for (auto to : graph[v]) {
        if (to == p) continue;
        const auto it = separators.find(make_pair(min(v, to), max(v, to)));
        if (it != separators.end() && it->second) continue;  // separator edge
        find_comps(to, v);
    }
}

// Reads one test case and prints the minimum number of merges.
//
//  1. Root the tree at 1, then build binary lifting (dfs) and HLD (decompose).
//  2. For every colour, take the LCA of its vertices. Each vertex adds +1 on the
//     path from itself up to (excluding) that LCA. A tree edge whose counter
//     stays 0 is crossed by no colour, so it is a separator.
//  3. Contract the non-separator edges into components. Contracting connected
//     pieces of a tree yields a tree.
//  4. With a single component nothing needs merging. Otherwise every leaf
//     component must be merged with another one, and pairing leaves gives
//     ceil(leaves / 2) merges.
void solve(int test) {
    int k;
    cin >> n >> k;
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        cin >> a >> b;
        graph[a].pb(b);
        graph[b].pb(a);
    }
    for (int i = 1; i <= n; i++) {
        cin >> color[i];
        states[color[i]]++;
    }
    dfs(1, 1);
    timer = 0;  // decompose() reuses the counter to number positions 1..n
    decompose(1, 1);

    // LCA of every colour class, folded incrementally (LCA is associative and
    // commutative, so visiting order does not matter). Vertex ids start at 1,
    // so 0 means "no vertex of this colour yet". A colour with no vertices keeps
    // 0 and is never looked up below.
    for (int i = 1; i <= n; i++) {
        int &acc = color_lca[color[i]];
        acc = (acc == 0) ? i : get_lca(acc, i);
    }

    for (int i = 1; i <= n; i++) {
        hld_update(i, color_lca[color[i]]);
    }

    find_states(1, 0);
    for (int i = 1; i <= n; i++) {
        if (!comp[i]) {
            num++;
            find_comps(i, i);
        }
    }
    for (int i = 1; i <= n; i++) {
        for (auto to : graph[i]) {
            if (comp[i] != comp[to]) {
                comps_graph[comp[i]].pb(comp[to]);
            }
        }
    }

    if (num == 1) {
        cout << 0 << '\n';
        return;
    }
    int leaves = 0;
    for (int c = 1; c <= num; c++) {
        if (sz(comps_graph[c]) == 1) leaves++;
    }
    cout << (leaves + 1) / 2 << '\n';
}

// Entry point. Unties cin from C stdio and from cout before reading, because
// solve() reads about 3n integers and uses only iostreams. Then runs the
// single test case.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int t = 1;
    //~ cin >> t;
    for (int i = 1; i <= t; i++) {
        solve(i);
    }

    return 0;
}
/* Sample test (expected output: 1):
11 6

1 2
2 4
4 3
2 5
5 6
2 8
8 7
7 9
7 10 
10 11
5
6
4
3
4
5
2
1
2
1
2
*/
# 결과 실행 시간 메모리 Grader output
1 Incorrect 12 ms 23796 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 12 ms 23796 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 12 ms 23796 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 256 ms 47072 KB Output is correct
2 Incorrect 347 ms 55952 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 12 ms 23796 KB Output isn't correct
2 Halted 0 ms 0 KB -