//#pragma GCC optimize("O3,unroll-loops")
//#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
#include <bits/stdc++.h>
//#define int long long
#define ll long long
#define fi first
#define se second
#define pb push_back
#define all(lmao) lmao.begin(), lmao.end()
using namespace std;
typedef pair<int,int> pii;
typedef tuple<int,int,int> tp;
const int N = 5e5 + 5;
const ll oo = 1e16;
const int mod = 1e9 + 7;
const int base = 23;
int n, k, ans, X[N], Y[N], col[N], a[N], f[N][22], dp[N], in[N], en[N], demin, b[N], st[N << 1];
vector<int> p[N], vr[N], open[N];
bool fl[N];
void update(int i){
i += n - 1;
st[i] ++;
while(i > 1){
i /= 2;
st[i]++;
}
}
int get(int l,int r){
r++;
int ret = 0;
for(l += n - 1, r += n - 1; l < r; l /= 2, r /= 2){
if(l & 1) ret += st[l ++];
if(r & 1) ret += st[-- r];
}
return ret;
}
void pre(int u,int v){
dp[u] = ((int)p[u].size() == 1);
if(v == 0) for(int i = 0; i <= 20; i ++) f[u][i] = u;
else{
f[u][0] = v;
for(int i = 1; i <= 20; i ++) f[u][i] = f[f[u][i - 1]][i - 1];
}
in[u] = ++demin;
b[demin] = u;
for(auto j : p[u]){
if(j == v) continue;
pre(j, u);
dp[u] += dp[j];
}
en[u] = demin;
}
bool kt(int u,int v){
return in[u] <= in[v] && en[u] >= in[v];
}
int lca(int u,int v){
int kq = u;
if(kt(u, v)) return u;
else{
for(int i = 20; i >= 0; i --){
if(kt(f[u][i], v)) kq = f[u][i];
else u = f[u][i];
}
return kq;
}
}
void dfs(int u,int v){
for(auto j : p[u]){
if(j == v) continue;
dfs(j, u);
fl[u] = max(fl[j], fl[u]);
}
for(auto j : open[u]){
for(auto o : vr[j]) update(in[o]);
}
if(!fl[u] && v && get(in[u], en[u]) != (en[u] - in[u] + 1)){
ans++;
fl[u] = true;
}
}
signed main(){
ios_base::sync_with_stdio(0); cin.tie(0); cout.tie(0);
#define task "v"
if(fopen(task ".inp","r")){
freopen(task ".inp","r",stdin);
freopen(task ".out","w",stdout);
}
cin >> n >> k;
for(int i = 1; i <= n - 1; i ++){
cin >> X[i] >> Y[i];
p[X[i]].pb(Y[i]);
p[Y[i]].pb(X[i]);
}
int root = 1, tmp = 1;
for(int i = 1; i <= n; i ++){
if((int)p[i].size() > tmp) root = i, tmp = (int)p[i].size();
int x; cin >> x;
vr[x].pb(i);
col[i] = x;
}
if(n == 1){
cout << 0;
return 0;
}
if(n == 2){
if(col[1] == col[2]){
cout << 0 << "\n";
}else cout << 1 << "\n";
return 0;
}
in[n + 1] = n + 1;
pre(root, 0);
for(int i = 1; i <= k; i ++){
if(vr[i].empty()) continue;
sort(all(vr[i]), [&] (int x,int y){return in[x] < in[y];});
int mi = vr[i][0];
for(int j = 0; j < vr[i].size() - 1; j ++){
int u = lca(vr[i][j], vr[i][j + 1]);
if(in[u] < in[mi]) mi = u;
}
open[mi].pb(i);
}
dfs(root, 0);
cerr << (ans + 1) / 2 << "\n";
cout << (ans + 1) / 2;
}
Compilation message
mergers.cpp: In function 'int main()':
mergers.cpp:132:26: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
132 | for(int j = 0; j < vr[i].size() - 1; j ++){
| ~~^~~~~~~~~~~~~~~~~~
mergers.cpp:97:16: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
97 | freopen(task ".inp","r",stdin);
| ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~
mergers.cpp:98:16: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
98 | freopen(task ".out","w",stdout);
| ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
10 ms |
51548 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
10 ms |
51548 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
10 ms |
51548 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
74 ms |
65740 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
10 ms |
51548 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |