Submission #1210843

# | Time | Username | Problem | Language | Result | Execution time | Memory
1210843 | | badge881 | Mergers (JOI19_mergers) | C++20 | 100 / 100 | 814 ms | 224664 KiB
#include <bits/stdc++.h>
using namespace std;
// NOTE: from here on every `int` token expands to `long long`
// (which is why the entry point is declared `signed main`).
#define int long long
// maxn / maxk: capacities of the per-vertex and per-value arrays.
// lg: number of binary-lifting levels stored per vertex in `st`.
const int maxn = 5e5 + 5, maxk = 5e5+5, lg = 20;
// n: number of vertices (n-1 edges are read); k: values 1..k are iterated in main.
int n, k;
// Adjacency lists on the ORIGINAL vertex labels. After DFS0 each list holds
// only the children of the vertex, sorted by subtree size (largest first).
vector<int> ADJ[maxn];
// A[i]: the value read for original vertex i.
int A[maxn];

// All indexed by ORIGINAL label except `lim`:
//   SZ[u]    - subtree size of u (filled by DFS0)
//   FA[u]    - parent of u (filled by DFS0)
//   newid[u] - preorder number assigned to u by DFS1 (1-based)
//   lim[x]   - largest preorder number inside the subtree of the vertex whose
//              preorder number is x (indexed by PREORDER number)
//   cnt      - running preorder counter used by DFS1
int SZ[maxn], FA[maxn], newid[maxn], lim[maxn], cnt = 0;
// Sort predicate on original vertex labels: true when x has the strictly
// larger subtree, i.e. vertices with bigger subtrees are ordered first.
bool cmp(int x, int y) {
    return SZ[y] < SZ[x];
}
// Roots the tree below u (original labels): records each child's parent in FA,
// strips the parent link from every child's adjacency list so that ADJ keeps
// children only, accumulates subtree sizes in SZ, and finally orders u's
// children by subtree size, largest first.
void DFS0(int u) {
    SZ[u] = 1;  // u itself
    for (int child : ADJ[u]) {
        FA[child] = u;
        // Remove the edge back to u before descending, so the recursive call
        // iterates over child's own children only.
        vector<int> &links = ADJ[child];
        links.erase(find(links.begin(), links.end(), u));
        DFS0(child);
        SZ[u] += SZ[child];
    }
    // All child sizes are final here, so the comparator sees complete values.
    sort(ADJ[u].begin(), ADJ[u].end(), cmp);
}
// Preorder numbering pass over the rooted tree (children lists in ADJ).
// Assigns newid[u] = next counter value on entry, and once the whole subtree
// has been numbered stores the last number handed out in lim[newid[u]], so the
// subtree of u occupies the contiguous range [newid[u], lim[newid[u]]].
void DFS1(int u) {
    newid[u] = ++cnt;
    const int label = newid[u];
    for (int child : ADJ[u]) {
        DFS1(child);
    }
    lim[label] = cnt;
}

// Everything below is indexed by PREORDER number (newid), not original label.
// adj[x]: children of x, in the same order as the sorted ADJ lists.
vector<int> adj[maxn];
// fa[x]: parent of x (0 for the root); sz[x]: subtree size of x.
// (sz is filled in main but not read anywhere in the visible code.)
int fa[maxn], sz[maxn];
// mem[c]: preorder numbers of all vertices whose value is c (sorted in main).
vector<int> mem[maxk];
// a[x]: value of vertex x (filled in main, not read in the visible code).
int a[maxn];

// Binary-lifting table: st[u][i] is the ancestor 2^i steps above u,
// or 0 when no such ancestor exists (0 is never a valid vertex number).
int st[maxn][lg];
// Fills the binary-lifting table `st` for the subtree of u (preorder labels).
// st[u][0] must already be set by the caller (it stays 0 for the root); higher
// levels are derived from it, then each child gets its level-0 entry and is
// processed recursively.
void dfs0(int u) {
    for (int level = 1; level < lg; ++level) {
        const int mid = st[u][level - 1];
        // Stop as soon as the jump would leave the tree: either u has no
        // 2^(level-1)-th ancestor, or that ancestor has none of its own.
        if (mid == 0 || st[mid][level - 1] == 0) {
            break;
        }
        st[u][level] = st[mid][level - 1];
    }
    for (int child : adj[u]) {
        st[child][0] = u;
        dfs0(child);
    }
}

// Indexed by preorder number. main first adds, at the vertex returned by lca()
// for each value, the number of vertices carrying that value; it then converts
// the array in place into prefix sums so subtree totals become range queries.
int psum[maxn];
int lca(int u, int c) {
    int l = mem[c][0], r = mem[c].back();
    for (int i=lg-1;i>=0;i--) if (st[u][i]) {
        int v = st[u][i];
        if (!(v<=l && r<=lim[v])) u = v;
    }
    if (!(u<=l && r<=lim[u])) u = fa[u];
    return u;
}

// Indexed by preorder number; all filled in main.
//   sep[x]     - for x >= 2: the psum credits inside x's subtree add up to
//                exactly the subtree's size
//   sepleaf[x] - sep[x] holds and no other vertex in x's subtree has sep set
//   sum[x]     - prefix count of sep[1..x]
bool sep[maxn], sepleaf[maxn];
int sum[maxn];

// Entry point. Reads n, k, the n-1 tree edges and the value of every vertex
// from stdin and prints one integer to stdout.
//
// Steps:
//   1. Root the tree at vertex 1 and renumber the vertices in preorder
//      (DFS0 / DFS1), so the subtree of x is the label range [x, lim[x]].
//   2. For every value c that occurs, find the deepest vertex whose subtree
//      holds all vertices of value c (lca) and credit it with the group size.
//   3. Mark sep[x] (x >= 2) when the credits inside x's subtree sum to the
//      subtree size.
//   4. Let `ans` be the number of marked vertices with no other marked vertex
//      in their subtree, and `add` be 1 if some marked vertex that is not one
//      of those has every marked vertex in its subtree. Print
//      ceil((ans + add) / 2).
signed main() {
    ios::sync_with_stdio(0); cin.tie(0);
    cin >> n >> k;
    for (int i=1;i<=n-1;i++) {
        int u, v;
        cin >> u >> v;
        ADJ[u].push_back(v), ADJ[v].push_back(u);
    }
    for (int i=1;i<=n;i++) cin >> A[i];

    // Step 1: root at vertex 1, then assign preorder numbers.
    DFS0(1);
    DFS1(1);

    // Mirror the rooted tree into preorder-label space.
    for (int i=1;i<=n;i++) {
        for (int j:ADJ[i]) adj[newid[i]].push_back(newid[j]);
        // The root (original label 1) has no parent; 0 acts as "none".
        fa[newid[i]] = (i != 1) ? newid[FA[i]] : 0;
        sz[newid[i]] = SZ[i];
    }
    for (int i=1;i<=n;i++) {
        a[newid[i]] = A[i];
        mem[A[i]].push_back(newid[i]);
    }
    dfs0(1);

    // Step 2: credit each value's group size at the group's common ancestor.
    for (int i=1;i<=k;i++) {
        // A value that no vertex carries has no group; indexing mem[i][0]
        // (here and inside lca) would be undefined behaviour, so skip it.
        if (mem[i].empty()) continue;
        sort(mem[i].begin(), mem[i].end());
        psum[lca(mem[i][0], i)] += mem[i].size();
    }
    for (int i=1;i<=n;i++) psum[i] += psum[i-1];

    // Step 3: the root (label 1) is deliberately excluded.
    for (int i=2;i<=n;i++) {
        sep[i] = (psum[lim[i]] - psum[i-1] == lim[i] - i + 1);
    }

    for (int i=1;i<=n;i++) sum[i] = sum[i-1] + sep[i];

    // Step 4: marked vertices whose subtree contains no other marked vertex.
    for (int i=1;i<=n;i++) sepleaf[i] = (sep[i] && sum[lim[i]] - sum[i-1] == 1);

    // `add` is set when one marked, non-sepleaf vertex has all marked
    // vertices inside its subtree.
    bool add = false;
    for (int i=2;i<=n;i++) if (sum[lim[i]] - sum[i-1] == sum[n] && sep[i] && !sepleaf[i]) add = true;

    int ans = 0;
    for (int i=1;i<=n;i++) ans += sepleaf[i];
    // Both branches equal ceil((ans + add) / 2).
    if (ans%2==1) cout << (ans+1)/2 << '\n';
    else cout << ans/2 + add << '\n';
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...