This submission was migrated from the previous version of oj.uz, which used a different machine for grading. It may get a different result if resubmitted.
#include <bits/stdc++.h>
// Contest shorthand macros.
// NOTE(review): `x`/`y` rewrite every later identifier spelled x or y; neither
// they nor mod, rng, ld and pii are used anywhere in this file.
#define pb push_back
#define x first
#define y second
#define all(a) a.begin(), a.end()
#define sz(a) (int)a.size()
using namespace std;

using ll = long long;
using ld = long double;
using pii = pair<int, int>;

// Capacity of every per-vertex array (500000 vertices plus slack).
constexpr int maxn = 500100;
constexpr int mod = 1000000007;
mt19937 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

int n, k;              // vertex count and state count read in main
int anc[maxn][20];     // anc[v][i]: 2^i-th ancestor of v (0 above the root)
int dep[maxn];         // depth of v, root = 1, dep[0] = 0
int par[maxn];         // DSU parent links
int highest[maxn];     // per DSU root: shallowest LCA its component must reach
vector<int> adj[maxn]; // tree adjacency lists
vector<int> grp[maxn]; // grp[s]: vertices belonging to state s
vector<int> g[maxn];   // adjacency of the contracted tree (indexed by DSU root)
// Fills dep[] and the binary-lifting table anc[][] for every vertex in the
// subtree hanging from `v`. The caller must already have set anc[v][0] (the
// parent of v, or 0 for the global root; dep[0] is never written and stays 0).
//
// Uses an explicit stack instead of recursion. The recursive form needs one
// call frame per tree level. On a path-shaped tree with up to maxn (~5e5)
// vertices, that can overflow a default-sized call stack.
void predfs(int v)
{
    std::vector<int> pending{v};
    while(!pending.empty())
    {
        const int cur = pending.back();
        pending.pop_back();
        const int parent = anc[cur][0];
        // `parent` was popped before `cur` was pushed, so its depth and
        // anc row are already final.
        dep[cur] = dep[parent] + 1;
        for(int i = 1; i < 20; i++)
            anc[cur][i] = anc[anc[cur][i - 1]][i - 1];
        for(int to : adj[cur])
        {
            if(to != parent)
            {
                anc[to][0] = cur;
                pending.push_back(to);
            }
        }
    }
}
int lca(int a, int b)
{
if(dep[a] > dep[b]) swap(a, b);
for(int i = 19; i >= 0; i--)
if(dep[anc[b][i]] >= dep[a])
b = anc[b][i];
if(a == b) return a;
for(int i = 19; i >= 0; i--)
if(anc[a][i] != anc[b][i])
a = anc[a][i], b = anc[b][i];
return anc[a][0];
}
// Returns the DSU representative of v. Afterwards every vertex on the walk
// from v points directly at that representative (path compression).
int root(int v)
{
    int rep = v;
    while(par[rep] != rep)
        rep = par[rep];
    // Second pass: repoint the walked path straight at the representative.
    while(par[v] != rep)
    {
        const int up = par[v];
        par[v] = rep;
        v = up;
    }
    return rep;
}
// Merges the DSU sets containing a and b: b's representative is attached
// under a's. The surviving root keeps whichever `highest` target is shallower.
void unite(int a, int b)
{
    const int ra = root(a);
    const int rb = root(b);
    if(ra != rb)
    {
        par[rb] = ra;
        if(dep[highest[ra]] > dep[highest[rb]])
            highest[ra] = highest[rb];
    }
}
// Reads a tree of n vertices and the state of each vertex. Contracts every
// state's vertices, together with the tree paths up to their common LCA,
// into one DSU component. Prints ceil(leaves / 2), where `leaves` is the
// number of degree-1 nodes in the contracted tree. If everything collapses
// into one component there are no leaves, so 0 is printed.
int main()
{
    ios_base::sync_with_stdio(false), cin.tie(0);
    cin >> n >> k;
    for(int i = 1; i <= n; i++)
        par[i] = i;
    for(int i = 0; i < n - 1; i++)
    {
        int a, b;
        cin >> a >> b;
        adj[a].pb(b);
        adj[b].pb(a);
    }
    // Root the tree at vertex 1 and build dep[] / anc[][] for lca().
    predfs(1);
    for(int i = 1; i <= n; i++)
    {
        // Named `state` so it does not shadow the global contracted graph g[].
        int state;
        cin >> state;
        grp[state].pb(i);
    }
    // Each vertex's target is the LCA of all vertices of its state.
    for(int kk = 1; kk <= k; kk++)
    {
        if(grp[kk].empty())
            continue;
        int LCA = grp[kk][0];
        for(int j = 1; j < sz(grp[kk]); j++)
            LCA = lca(LCA, grp[kk][j]);
        // No unions have happened yet, so root(a) == a here.
        for(int a : grp[kk])
            highest[root(a)] = LCA;
    }
    // Deepest vertices first: a vertex whose component has not yet reached
    // its target is merged into its parent. A vertex that *is* its
    // component's target keeps the edge to its parent as a contracted edge.
    vector<int>ord;
    ord.reserve(n);
    for(int i = 1; i <= n; i++)
        ord.pb(i);
    sort(all(ord), [&](int a, int b)
    {
        return dep[a] > dep[b];
    });
    for(int v : ord)
    {
        if(highest[root(v)] != v)
            unite(v, anc[v][0]);
    }
    // Build the contracted tree. Each tree edge appears twice in adj[];
    // the `<` test keeps one copy and drops edges inside a component.
    for(int i = 1; i <= n; i++)
    {
        const int ri = root(i);
        for(int to : adj[i])
        {
            const int rt = root(to);
            if(ri < rt)
            {
                g[ri].pb(rt);
                g[rt].pb(ri);
            }
        }
    }
    int leaves = 0;
    for(int i = 1; i <= n; i++)
        if(sz(g[i]) == 1) ++leaves;
    cout << (leaves + 1) / 2 << "\n";
    return 0;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |