//#pragma GCC target("avx2")
#pragma GCC optimize("O3")
#include <iostream>
#include <vector>
#include <algorithm>
#include <math.h>
#include <set>
#include <stack>
#include <iomanip>
#include <bitset>
#include <map>
#include <cassert>
#include <array>
#include <queue>
#include <cstring>
#include <random>
#include <unordered_set>
#include <unordered_map>
// ---- Shorthand macros -------------------------------------------------------
#define pqueue priority_queue
#define pb(x) push_back(x)
#define all(x) x.begin(), x.end()
// Optional toggles, currently disabled:
// #define endl '\n'
//#define int long long
using namespace std;

// ---- Scalar aliases ---------------------------------------------------------
using ll = long long;
using ull = unsigned long long;
using ld = long double;

// ---- Container aliases ------------------------------------------------------
using vi = vector<int>;
using vvi = vector<vector<int>>;
using pii = pair<int, int>;
using vpii = vector<pair<int, int>>;
using vb = vector<bool>;
using vs = vector<string>;
using vc = vector<char>;

// ---- Common numeric constants -----------------------------------------------
const int inf = 1e9;          // large sentinel that still fits in an int
const ll mod = 1e9 + 7;       // 1'000'000'007
const ll mod2 = 998244353;
const ld eps = 1e-14;         // tolerance for floating-point comparisons
// Speeds up iostream-based I/O: stops synchronising the C++ streams with C
// stdio and detaches cin from cout so reads no longer force a flush of cout.
// To redirect the streams to files while debugging, add freopen() calls here.
void fast_io(){
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
}
// Capacity of every per-vertex array below.
const int maxn = 5e5 + 228;

// Adjacency lists of the tree, filled by solve().
vi g[maxn];
// Ancestor jump table written by dfs(): up[v][0] is the parent of v and each
// higher level is two jumps of the level beneath it.
int up[maxn][20];
// Per-vertex counters: seeded in solve(), summed over subtrees in dfs2().
int a[maxn];
// Entry / exit stamps assigned by dfs(); compared by upper().
int tin[maxn];
int tout[maxn];
// Disjoint-set forest: representative pointer and component size.
// NOTE(review): `link` shares its name with POSIX link() from <unistd.h>; the
// declaration would clash if that header were ever pulled in.
int link[maxn];
int sizex[maxn];
// Shared clock that dfs() advances when entering and leaving a vertex.
int timer = 0;

// Defined further down; declared here because dfs2() calls it.
void unite(int a, int b);
// Depth-first walk from `v` (whose parent is `p`) that records, for every
// vertex it reaches:
//   * tin[v] / tout[v] - stamps taken from the shared `timer` on entry / exit;
//   * up[v][lvl]       - jump table; level 0 is the parent, every further
//                        level is obtained by jumping twice through the level
//                        below.
// With the default p = 0, calling dfs(0) makes vertex 0 its own parent, so
// jumps from the root stay at the root.
void dfs(int v, int p = 0){
    tin[v] = timer++;
    up[v][0] = p;
    for(int lvl = 1; lvl < 20; ++lvl){
        const int mid = up[v][lvl - 1];
        up[v][lvl] = up[mid][lvl - 1];
    }
    for(const int child : g[v]){
        if(child == p)
            continue;   // do not walk back up the edge we came from
        dfs(child, v);
    }
    timer += 1;
    tout[v] = timer;
}
// Post-order pass over the tree rooted at the first call's vertex.
//
// For every vertex `v` it adds the counters of all children into a[v], so that
// a[v] ends up holding the sum of the seeded counters over v's whole subtree.
// If that sum is non-zero, `v` is merged with its parent `p` in the
// disjoint-set structure via unite().
//
// `p` is the parent of `v`; the default -1 means "no parent" (the root).
void dfs2(int v, int p = -1){
    for(int child : g[v]){
        if(child == p)
            continue;
        dfs2(child, v);
        a[v] += a[child];
    }
    // The root has no parent edge to merge across. Without the p != -1 check a
    // non-zero total at the root would pass -1 to unite() and index the
    // disjoint-set arrays out of bounds.
    if(a[v] != 0 && p != -1)
        unite(v, p);
}
// True when the [tin, tout] interval of `a` encloses that of `b`, i.e. when
// `a` is `b` itself or an ancestor of `b` according to the stamps set by dfs().
bool upper(int a, int b){
    if(tin[b] < tin[a])
        return false;
    return tout[a] >= tout[b];
}
int lca(int a, int b){
if(upper(a, b))
return a;
if(upper(b, a))
return b;
int cur = a;
for(int i=19; i>=0; i--){
if(!upper(up[cur][i], b)){
cur = up[cur][i];
}
}
cur = up[cur][0];
return cur;
}
// Returns the representative of the set containing `a` and re-points every
// vertex on the walked chain straight at that representative.
int find(int a){
    // First pass: locate the representative (the vertex that links to itself).
    int root = a;
    while(link[root] != root)
        root = link[root];
    // Second pass: compress the chain from `a` up to the representative.
    while(link[a] != root){
        const int next = link[a];
        link[a] = root;
        a = next;
    }
    return root;
}
// Merges the sets containing `a` and `b`. The representative of the set that
// is not larger gets attached beneath the other one, and the surviving
// representative's size is increased accordingly. No-op if both vertices are
// already in the same set.
void unite(int a, int b){
    const int ra = find(a);
    const int rb = find(b);
    if(ra == rb)
        return;
    const bool a_is_bigger = sizex[ra] > sizex[rb];
    const int child = a_is_bigger ? rb : ra;
    const int parent = a_is_bigger ? ra : rb;
    link[child] = parent;
    sizex[parent] += sizex[child];
}
void solve(){
int n, k;
cin >> n >> k;
vvi col(k);
for(int i=0; i+1<n; i++){
int a, b;
cin >> a >> b;
a--, b--;
g[a].pb(b);
g[b].pb(a);
}
for(int i=0; i<n; i++){
int a;
cin >> a;
col[a-1].pb(i);
}
dfs(0);
for(int i=0; i<k; i++){
int cur = col[i][0];
for(int j=1; j<col[i].size(); j++){
cur = lca(cur, col[i][j]);
}
// cout << i << " " << cur << endl;
for(int j:col[i]){
a[j]++;
a[cur]--;
}
}
for(int i=0; i<maxn; i++){
link[i] = i;
sizex[i] = 1;
}
dfs2(0);
int ans = 0;
vi cnt(n, 0);
for(int i=0; i<n; i++){
for(int j:g[i]){
if(find(i) != find(j)){
cnt[find(i)]++;
cnt[find(j)]++;
}
}
}
for(int i=0; i<n; i++){
if(cnt[i] == 2 && find(i) == i){
// cout << 1 << endl;
ans++;
}
}
// cout << endl;
// cout << ans << endl;
cout << max((ans+1)/2, 0) << endl;
}
/*
5 4
1 2
2 3
3 4
3 5
1
2
1
3
4
=======
5 4
1 2
2 3
3 4
4 5
1
2
3
4
1
=======
2 2
1 2
1
2
*/
// Program entry point: configures the streams, then runs solve() once per test
// case.
signed main(){
    fast_io();
    cout << fixed << setprecision(10);
    // Single test case per run; read `tests` from the input instead if the
    // task ever becomes multi-test.
    const int tests = 1;
    for(int done = 0; done < tests; ++done){
        solve();
    }
}
/*
Compilation message

capital_city.cpp: In function 'void solve()':
capital_city.cpp:138:23: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
  138 |         for(int j=1; j<col[i].size(); j++){
      |                      ~^~~~~~~~~~~~~~

Judge results (four result tables, in the order they were reported)

# | Result (결과) | Execution time (실행 시간) | Memory (메모리) | Grader output
1 | Incorrect     | 10 ms                      | 15948 KB        | Output isn't correct
2 | Halted        | 0 ms                       | 0 KB            | -

# | Result (결과) | Execution time (실행 시간) | Memory (메모리) | Grader output
1 | Incorrect     | 10 ms                      | 15948 KB        | Output isn't correct
2 | Halted        | 0 ms                       | 0 KB            | -

# | Result (결과) | Execution time (실행 시간) | Memory (메모리) | Grader output
1 | Incorrect     | 336 ms                     | 65352 KB        | Output isn't correct
2 | Halted        | 0 ms                       | 0 KB            | -

# | Result (결과) | Execution time (실행 시간) | Memory (메모리) | Grader output
1 | Incorrect     | 10 ms                      | 15948 KB        | Output isn't correct
2 | Halted        | 0 ms                       | 0 KB            | -
*/