This submission was migrated from a previous version of oj.uz, which used a different machine for grading. It may produce a different result if resubmitted.
/* IN THE NAME OF GOD */
/* |\/| /-\ |\| | |\/| /-\ */
#include "bits/stdc++.h"
using namespace std;
// Shorthand macros used throughout the file.
// sz(x): container size as a signed integer. The argument is not parenthesised,
// so pass only a simple expression. Once `int` is redefined below, the cast
// written here expands to `long long` at every use site.
#define sz(x) (int)x.size()
// Replaces the token `endl` with a plain newline character.
#define endl '\n'
#define pb push_back
// all(x): the begin/end iterator pair of x, for passing to algorithms.
#define all(x) x.begin(), x.end()
// F / S: the two members of a pair.
#define F first
#define S second
#define mr make_pair
// From here on every `int` in the file is a 64-bit `long long`; this is why
// main is declared with the return type int32_t.
#define int long long
// pii therefore expands to pair<long long, long long>.
#define pii pair<int, int>
typedef long double ld;
typedef long long ll;
// Random generator seeded from the high-resolution clock; not used by the code in this file.
mt19937 rng (chrono::high_resolution_clock::now().time_since_epoch().count());
// Capacity of every per-vertex array below (vertices are indexed from 1).
const int N = 5e5 + 10;
// MOD, inf and INF are not referenced by the code in this file.
const ll MOD = 1e9 + 9;
const ll inf = 2e17;
const ll INF = 7e15;
// g[v]: adjacency list of vertex v (each edge is stored in both directions by main).
vector<int> g[N];
// cnt[c]  : number of vertices whose value s[i] equals c (filled in main, never read).
// s[i]    : value read for vertex i in main.
// sum[v]  : total of f over the subtree of v, accumulated by moorgh.
// k       : second integer of the input; main iterates values 1..k.
// T       : timestamp counter advanced by dfs on entering and on leaving a vertex.
// par[v][i]: ancestor of v that is 2^i steps up (par[v][0] is the parent); entries
//            that would go above the root stay 0 because globals are zero-initialised.
// h[v]    : depth of v, with the dfs root at depth 0.
// tin[v] / tout[v]: entry / exit timestamps of v assigned by dfs.
// f[v]    : per-vertex delta set in main: +1 for every vertex, and minus the size of
//           a value's vertex list at the LCA of that list.
// d[v]    : degree of v in the compressed tree that main builds.
int cnt[N], s[N], sum[N], k, T, par[N][20], h[N], tin[N], tout[N], f[N], d[N];
// col[c]: vertices i with s[i] == c, in increasing order of i.
vector<int> col[N];
// Depth-first traversal from v (entered from p) that fills the ancestor table
// par, the depths h and the entry/exit timestamps tin/tout.
// The caller must have set par[v][0] and h[v] already; for the initial call
// both are still the zero values of the global arrays.
void dfs(int v, int p){
    // Extend the ancestor table: the 2^level-th ancestor is the 2^(level-1)-th
    // ancestor of the 2^(level-1)-th ancestor.
    int level = 1;
    while(level < 20){
        int halfway = par[v][level - 1];
        par[v][level] = par[halfway][level - 1];
        level++;
    }
    // Entering v consumes one tick of the global clock.
    T++;
    tin[v] = T;
    for(size_t idx = 0; idx < g[v].size(); idx++){
        int child = g[v][idx];
        if(child != p){
            // Record parent and depth before descending so the child can
            // build its own ancestor row.
            par[child][0] = v;
            h[child] = h[v] + 1;
            dfs(child, v);
        }
    }
    // Leaving v consumes another tick.
    T++;
    tout[v] = T;
}
// Returns the lowest common ancestor of u and v using the ancestor table par
// and the depths h filled by dfs.
int lca(int u, int v){
    // Keep v as the deeper (or equally deep) endpoint.
    if(h[v] < h[u])
        swap(u, v);
    // Lift v by the depth difference, one binary digit at a time.
    int gap = h[v] - h[u];
    for(int bit = 0; bit < 20; bit++){
        if((gap >> bit) & 1)
            v = par[v][bit];
    }
    if(u == v)
        return u;
    // Lift both vertices by the largest jumps that keep them distinct; they
    // end up as two different children of the answer.
    int bit = 19;
    while(bit >= 0){
        int upU = par[u][bit];
        int upV = par[v][bit];
        if(upU != upV){
            v = upV;
            u = upU;
        }
        bit--;
    }
    return par[v][0];
}
// Post-order pass: adds to sum[v] the total of f over the subtree of v
// (v entered from p). sum is a zero-initialised global, so after one pass
// from the root sum[v] equals that subtree total.
void moorgh(int v, int p){
    for(auto it = g[v].begin(); it != g[v].end(); ++it){
        if(*it == p)
            continue;
        // Finish the child's subtree first, then fold its total into v.
        moorgh(*it, v);
        sum[v] += sum[*it];
    }
    // Finally add v's own contribution.
    sum[v] += f[v];
}
int32_t main(){
ios_base:: sync_with_stdio(0), cin.tie(0), cout.tie(0);
int n;
cin >> n >> k;
int u, v;
for(int i = 1; i < n; i++){
cin >> u >> v;
g[u].pb(v);
g[v].pb(u);
}
for(int i = 1; i <= n; i++){
cin >> s[i];
col[s[i]].pb(i);
cnt[s[i]]++;
}
dfs(1, 0);
vector<pii> st;
for(int i = 1; i <= k; i++){
if(sz(col[i]) == 0)
continue;
int root = col[i][0];
for(int v : col[i])
root = lca(root, v);
for(int v : col[i]){
f[root]--;
f[v]++;
}
}
moorgh(1, 0);
vector<pii> vec;
for(int i = 2; i <= n; i++){
if(sum[i] == 0){
vec.pb(mr(tin[i], i));
vec.pb(mr(tin[par[i][0]], par[i][0]));
}
}
sort(all(vec));
vec.resize(unique(all(vec)) - vec.begin());
st.clear();
for(pii v : vec)
st.pb(v);
for(int i = 0; i < sz(vec) - 1; i++)
st.pb(mr(tin[lca(vec[i].S, vec[i + 1].S)], lca(vec[i].S, vec[i + 1].S)));
sort(all(st));
st.resize(unique(all(st)) - st.begin());
vector<int> V;
for(pii v : st) {
while(!V.empty() && tout[V.back()] < tout[v.S])
V.pop_back();
if(!V.empty()){
d[v.S]++;
d[V.back()]++;
}
V.pb(v.S);
}
int t = 0;
for(pii v : st){
if(d[v.S] == 1)
t++;
}
cout << t / 2 + (t % 2);
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |