Submission #1143800

# | Time | Username | Problem | Language | Result | Execution time | Memory
1143800 |  | Persia | Capital City (JOI20_capital_city) | C++17 | 11 / 100 | 1610 ms | 548360 KiB
#include <bits/stdc++.h>

using namespace std;

// bit(i, x): bit i of x (0 or 1). Both arguments are parenthesised so that
// compound expressions such as bit(i, a | b) expand as intended.
#define bit(i, x) (((x) >> (i)) & 1)
#define ll long long
// sz(x): container size as a signed int (avoids signed/unsigned comparisons).
#define sz(x) ((int)(x).size())

const int N = 2e5 + 5;      // capacity of the per-town / per-colour arrays
const int mod = 998244353;  // not referenced in the visible code (template leftover?)

// Pseudo-random engine seeded from the steady clock, so every run differs.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

// Returns a uniformly distributed integer in the closed range [l, r].
ll rnd(ll l, ll r) {
    return uniform_int_distribution<ll>(l, r)(rng);
}

int n;             // number of towns (tree vertices, 1-indexed)
int k;             // number of colours / cities (1-indexed)
vector<int> G[N];  // tree adjacency lists
int c[N];          // c[v] = colour of town v

// Binary lifting over the tree rooted at town 1.
int f[N][19];  // f[v][i] = 2^i-th ancestor of v, 0 when it does not exist
int g[N][19];  // g[v][i] = auxiliary node for the 2^i towns starting at v and going up
int cnt2;      // highest graph node id handed out so far
int h[N];      // depth of each town (root stays at 0)

// Implication graph: ids 1..k are colour nodes, ids k+1..cnt2 are g[][] nodes.
vector<int> G2[19 * N];

// b[col] = towns whose colour is col.
vector<int> b[N];

// Tarjan SCC state and the condensation graph.
int num[19 * N];         // DFS discovery index (0 = unvisited)
int low[19 * N];         // Tarjan low-link value
int cnt;                 // discovery counter
int col[19 * N];         // SCC id of each node (0 = not yet assigned)
int ct;                  // number of SCCs found so far
int val[19 * N];         // number of colour nodes (ids <= k) inside each SCC
vector<int> stk;         // Tarjan node stack
int bad[19 * N];         // nonzero if the SCC reaches another SCC holding a colour node
bool mark[19 * N];       // memo flag for dp()
vector<int> G3[19 * N];  // condensation edges between SCC ids (duplicates possible)

// Value printed by main(): starts at k - 1 and is minimised over candidate SCCs.
int res;

// debugging stuff

// Appends the directed edge x -> y to the implication graph G2.
void addedge(int x, int y) {
    vector<int>& out = G2[x];
    out.emplace_back(y);
}

// Town x must lie at or below its ancestor p. Adds edges from the colour of x
// to a set of nodes that together cover every town on the path x .. p
// (inclusive). That is one lifting node g[][lvl] per set bit of the depth
// difference, plus the colour of p itself.
void addedge2(int x, int p) {
    assert(h[x] >= h[p]);
    const int from = c[x];
    const int dist = h[x] - h[p];
    int cur = x;
    for (int lvl = 0; lvl <= 17; ++lvl) {
        if (((dist >> lvl) & 1) == 0) continue;
        addedge(from, g[cur][lvl]);
        cur = f[cur][lvl];
    }
    addedge(from, c[p]);
}

// Roots the tree at u (default town 1) and visits every town in preorder. For
// each visited town x it:
//   * fills the binary-lifting table f[x][1..17] (f[x][0] and h[x] were set
//     when x's parent was visited);
//   * allocates 18 auxiliary graph nodes g[x][0..17]. g[x][i] points to
//     g[x][i-1] and to g[f[x][i-1]][i-1], and g[x][0] points to colour c[x];
//   * sets h[] and f[][0] for x's children.
// The recursive version recursed once per tree level. On a path-shaped tree
// that is n - 1 nested calls, which can overflow the native stack, so an
// explicit stack is used instead. Children are pushed in reverse adjacency
// order, so the preorder (and therefore the g[][] numbering and the edge order
// in G2) is the same as in the recursive formulation.
void predfs(int u = 1, int par = -1) {
    vector<pair<int, int>> todo;  // (town, parent) pairs still to expand
    todo.emplace_back(u, par);
    while (!todo.empty()) {
        const auto [x, px] = todo.back();
        todo.pop_back();

        for (int i = 1; i <= 17; i++) if (f[f[x][i - 1]][i - 1]) {
            f[x][i] = f[f[x][i - 1]][i - 1];
        }
        for (int i = 0; i <= 17; i++) {
            g[x][i] = ++cnt2;
            if (i > 0) {
                if (f[x][i - 1]) addedge(g[x][i], g[x][i - 1]);
                if (f[f[x][i - 1]][i - 1]) addedge(g[x][i], g[f[x][i - 1]][i - 1]);
            }
        }
        addedge(g[x][0], c[x]);

        for (auto it = G[x].rbegin(); it != G[x].rend(); ++it) {
            const int v = *it;
            if (v == px) continue;
            h[v] = h[x] + 1;
            f[v][0] = x;
            todo.emplace_back(v, x);
        }
    }
}

// Lowest common ancestor of towns x and y, using the f[][] lifting table and
// the depths in h[].
int lca(int x, int y) {
    int deep = x;
    int shallow = y;
    if (h[deep] < h[shallow]) swap(deep, shallow);

    // Lift the deeper town to the depth of the shallower one.
    const int diff = h[deep] - h[shallow];
    for (int lvl = 0; lvl <= 17; ++lvl) {
        if ((diff >> lvl) & 1) deep = f[deep][lvl];
    }
    if (deep == shallow) return deep;

    // Lift both towns while their ancestors still differ.
    for (int lvl = 17; lvl >= 0; --lvl) {
        const int a = f[deep][lvl];
        const int b2 = f[shallow][lvl];
        if (a == b2) continue;
        deep = a;
        shallow = b2;
    }
    return f[deep][0];
}

// Tarjan's strongly-connected-components algorithm on G2, started from node u.
// It assigns num[]/low[] indices and labels each finished node with its SCC id
// in col[]. SCC ids are 1, 2, ... in the order components complete. It also
// counts in val[] how many colour nodes (ids <= k) each SCC contains.
//
// An explicit frame stack replaces recursion. A DFS over G2 can be as deep as
// the number of graph nodes (up to k + 18*n), which would overflow the native
// call stack. Edges are examined in the same order as the recursive version,
// so the cnt and ct numbering is unchanged.
void dfs(int u) {
    // Each frame holds (node, index of the next outgoing edge to examine).
    // The vector is static so its capacity is reused across calls from main().
    static vector<pair<int, size_t>> frames;
    frames.clear();

    auto open = [](int x) {
        num[x] = low[x] = ++cnt;
        stk.push_back(x);
        frames.emplace_back(x, 0);
    };

    open(u);
    while (!frames.empty()) {
        const int x = frames.back().first;
        size_t& next = frames.back().second;
        if (next < G2[x].size()) {
            const int v = G2[x][next++];  // advance before open() may reallocate frames
            if (col[v]) continue;                      // v already belongs to a finished SCC
            if (num[v]) low[x] = min(low[x], num[v]);  // v is still on stk
            else open(v);                              // descend into v
            continue;
        }

        // All edges of x have been examined. If x is an SCC root, pop its component.
        if (num[x] == low[x]) {
            ct++;
            while (true) {
                const int top = stk.back();
                stk.pop_back();
                col[top] = ct;
                val[ct] += (top <= k);
                if (top == x) break;
            }
        }
        frames.pop_back();

        // Return to the parent frame and propagate the low-link.
        if (!frames.empty()) {
            const int parent = frames.back().first;
            low[parent] = min(low[parent], low[x]);
        }
    }
}

// Memoised reachability on the condensation graph G3. It ORs into bad[u] the
// bad[] value of every SCC reachable from u, then returns bad[u]. main() seeds
// bad[] with "has a direct edge into an SCC containing a colour node", so the
// result is nonzero iff u can reach such an SCC. mark[] records the SCCs that
// have already been expanded.
//
// An explicit frame stack replaces recursion, because a path in the
// condensation can be millions of SCCs long and would overflow the native
// stack. The order in which bad[] values are combined is the same as in the
// recursive formulation.
int dp(int u) {
    if (mark[u]) return bad[u];

    // Each frame holds (SCC id, index of the next successor to examine).
    static vector<pair<int, size_t>> frames;
    frames.clear();

    mark[u] = 1;
    frames.emplace_back(u, 0);
    while (!frames.empty()) {
        const int x = frames.back().first;
        size_t& next = frames.back().second;
        if (next < G3[x].size()) {
            const int v = G3[x][next++];
            if (mark[v]) {
                bad[x] |= bad[v];  // already expanded: reuse its value
            } else {
                mark[v] = 1;
                frames.emplace_back(v, 0);
            }
            continue;
        }
        frames.pop_back();
        if (!frames.empty()) bad[frames.back().first] |= bad[x];
    }
    return bad[u];
}

signed main(int argc, char* argv[]) {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    // Read the tree (n - 1 undirected edges) and the colour of every town.
    cin >> n >> k;
    for (int remaining = n - 1; remaining > 0; --remaining) {
        int x, y;
        cin >> x >> y;
        G[x].push_back(y);
        G[y].push_back(x);
    }
    for (int v = 1; v <= n; ++v) cin >> c[v];

    // Graph ids 1..k are reserved for colour nodes. predfs() numbers the
    // auxiliary lifting nodes from k + 1 upwards.
    cnt2 = k;
    predfs();

    // Each colour gets edges to every colour on the paths from its towns up
    // to their common LCA.
    for (int v = 1; v <= n; ++v) b[c[v]].push_back(v);
    for (int colour = 1; colour <= n; ++colour) {
        const vector<int>& towns = b[colour];
        if (towns.empty()) continue;
        int meet = towns.front();
        for (int t : towns) meet = lca(meet, t);
        for (int t : towns) addedge2(t, meet);
    }

    // Strongly connected components of the implication graph.
    for (int id = 1; id <= cnt2; ++id) {
        if (num[id] == 0) dfs(id);
    }

    // Build the condensation. Seed bad[] for SCCs that have a direct edge into
    // an SCC containing a colour node.
    for (int id = 1; id <= cnt2; ++id) {
        const int from = col[id];
        for (int nxt : G2[id]) {
            const int to = col[nxt];
            if (from == to) continue;
            bad[from] |= (val[to] > 0);
            G3[from].push_back(to);
        }
    }

    // A colour-bearing SCC that reaches no other colour-bearing SCC is a
    // candidate; it contributes val - 1.
    res = k - 1;
    for (int s = 1; s <= ct; ++s) {
        const bool blocked = dp(s) != 0;
        if (!blocked && val[s]) res = min(res, val[s] - 1);
    }
    cout << res;
    return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...