제출 #128932

| # | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리 |
|---|---|---|---|---|---|---|---|
| 128932 | | kuroni | Mergers (JOI19_mergers) | C++17 | 100 / 100 | 1566 ms | 64072 KiB |
#include <bits/stdc++.h>
using namespace std;

// Array capacities: node indices (1-based) must be < N, label values must be < K.
constexpr int N = 500005;
constexpr int K = 500005;

int n, k;   // node count and label count read from input
int u, v;   // scratch variables used by main() while reading edges
int ans = 0; // counter incremented by DFS_2; main() prints (ans + 1) / 2
int a[N];   // label of each node (indexes sum[] / cnt[])
int sub[N]; // subtree sizes filled by DFS_1
int f[N];   // per-node counts of closed subtrees accumulated by DFS_2

// Shared state of the small-to-large pass (add / DFS_2).
int cur = 0; // net change in how many labels c satisfy cnt[c] == 0 || cnt[c] == sum[c]
int sum[K];  // total number of nodes carrying each label
int cnt[K];  // number of nodes carrying each label in the currently-added set
std::vector<int> adj[N]; // undirected adjacency lists

// Adds `val` to cnt[a[x]] for every node x in the subtree rooted at u (whose
// parent is p), skipping the whole subtree rooted at `av` (av == 0 skips
// nothing, since nodes are numbered from 1).
// For every touched label it keeps `cur` up to date: cur is adjusted by the
// change in the number of labels c with cnt[c] == 0 or cnt[c] == sum[c].
//
// Uses an explicit stack instead of recursion so that removing or adding a
// long, path-like subtree cannot overflow the call stack. The visiting order
// does not matter: per label, the total change to `cur` telescopes to
// good(final cnt) - good(initial cnt).
void add(int u, int p, int val, int av = 0)
{
    // Reused between calls to avoid reallocating on every invocation; safe
    // because add() no longer calls itself.
    static std::vector<std::pair<int, int>> stk;
    stk.clear();
    stk.emplace_back(u, p);
    while (!stk.empty())
    {
        const auto [x, px] = stk.back();
        stk.pop_back();

        int &c = cnt[a[x]];
        const int total = sum[a[x]];
        cur -= (c == total || c == 0);
        c += val;
        cur += (c == total || c == 0);

        for (const int y : adj[x])
            if (y != px && y != av)
                stk.emplace_back(y, x);
    }
}

// Computes sub[x] (number of nodes in the subtree of x) for every node x in
// the subtree rooted at u, whose parent is p. Returns sub[u].
//
// Iterative: nodes are listed in BFS order (every parent before its
// children) and then folded into their parents in reverse order, so trees of
// depth up to N do not overflow the call stack.
int DFS_1(int u, int p = 0)
{
    std::vector<std::pair<int, int>> order; // (node, parent)
    order.emplace_back(u, p);
    for (std::size_t i = 0; i < order.size(); ++i)
    {
        const auto [x, px] = order[i]; // copy: emplace_back may reallocate
        sub[x] = 1;
        for (const int y : adj[x])
            if (y != px)
                order.emplace_back(y, x);
    }
    // Entry 0 is u itself; its parent p is outside the subtree and untouched.
    for (std::size_t i = order.size(); i-- > 1;)
        sub[order[i].second] += sub[order[i].first];
    return sub[u];
}

// Small-to-large ("DSU on tree") pass over the subtree rooted at u, whose
// parent is p. Relies on sub[] from DFS_1 to choose each node's heavy child.
//
// Call a subtree "closed" when, for every label c, it contains either none or
// all sum[c] of the nodes with that label (detected as cur == 0 while exactly
// that subtree's nodes are counted in cnt[]). For every node x visited:
//   * f[x] is increased by the number of maximal closed subtrees strictly
//     inside subtree x (a closed child counts 1, otherwise its own f[] is
//     inherited);
//   * ans is incremented when subtree x is closed and f[x] == 0, or, for the
//     root node 1, when f[1] == 1 (presumably: closed components that are
//     leaves after contracting non-separating edges).
// If `big` is true, the counts for subtree u remain in cnt[]/cur on return;
// otherwise they are removed again. Returns whether subtree u is closed.
//
// Uses an explicit frame stack instead of recursion so that deep trees do
// not overflow the call stack. The order is the same as the recursive form:
// light children in adjacency order, then the heavy child, then u itself.
bool DFS_2(int u, int p, bool big)
{
    struct Frame
    {
        int node;
        int parent;
        int heavy;      // child with the largest sub[], or 0 if none
        int next;       // next index into adj[node] to scan for light children
        bool keep;      // the `big` flag of this (virtual) call
        bool heavyDone; // heavy child has already been descended into
    };

    std::vector<Frame> stk;
    auto enter = [&](int x, int px, bool keep) {
        int mc = 0; // sub[0] == 0, so any real child with sub > 0 replaces it
        for (const int y : adj[x])
            if (y != px && sub[y] > sub[mc])
                mc = y;
        stk.push_back(Frame{x, px, mc, 0, keep, false});
    };

    enter(u, p, big);
    bool closed = false; // result of the most recently finished frame
    while (!stk.empty())
    {
        Frame &top = stk.back();
        const std::vector<int> &nbrs = adj[top.node];
        const int deg = static_cast<int>(nbrs.size());

        // Skip the parent and the heavy child while scanning light children.
        while (top.next < deg &&
               (nbrs[top.next] == top.parent || nbrs[top.next] == top.heavy))
            ++top.next;
        if (top.next < deg)
        {
            const int child = nbrs[top.next++];
            const int x = top.node; // `top` may dangle after enter()
            enter(child, x, false);
            continue;
        }
        if (top.heavy != 0 && !top.heavyDone)
        {
            top.heavyDone = true;
            const int child = top.heavy;
            const int x = top.node;
            enter(child, x, true);
            continue;
        }

        // All children processed; the heavy child's counts are still present.
        const Frame done = top;
        stk.pop_back();
        add(done.node, done.parent, 1, done.heavy);
        closed = (cur == 0);
        if (closed)
            ans += (f[done.node] == (done.node == 1));
        if (!done.keep)
            add(done.node, done.parent, -1);
        if (!stk.empty())
            f[stk.back().node] += (closed ? 1 : f[done.node]);
    }
    return closed;
}

int main()
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin >> n >> k;

    // n - 1 undirected edges, read through the global scratch variables u, v.
    for (int edge = 0; edge + 1 < n; ++edge)
    {
        std::cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    // One label per node (presumably the state of each city); sum[] tallies
    // how many nodes carry each label.
    for (int node = 1; node <= n; ++node)
    {
        int &label = a[node];
        std::cin >> label;
        ++sum[label];
    }

    DFS_1(1);
    DFS_2(1, 0, true);

    // Print ceil(ans / 2).
    std::cout << (ans + 1) / 2;
}
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...
#Verdict Execution timeMemoryGrader output
Fetching results...