#include<bits/stdc++.h>
using namespace std;
// Large sentinel constants; neither one is referenced in the visible part of this file.
const long long inf = (long long) 1e18 + 10;
const int inf1 = (int) 1e9 + 10;
// From this line on every `int` token expands to `long long`
// (which is why the entry point further down is declared as int32_t).
#define int long long
#define dbl long double
// `endl` is replaced by a newline character, so streaming it does not flush.
#define endl '\n'
// Shorthands for pair members and frequently used container calls.
#define sc second
#define fr first
#define mp make_pair
#define pb push_back
#define all(x) x.begin(), x.end()
// Capacity of every per-vertex array declared below.
const int maxn = 5e5+10;
// n, k      : the two integers read first by solve(); vertices are numbered 1..n, labels 1..k.
// tin[u]    : entry time assigned to u by dfs().
// p[u][i]   : ancestor of u reached by 2^i parent steps (the root is its own parent).
// ps[u][i]  : counter bumped when a marked path jumps 2^i steps up from u; solve() later
//             pushes these down so that ps[u][0] refers to the single edge u -> p[u][0].
// h[u]      : depth of u, with the root at depth 0.
// ds, dsz   : union-find parent links and set sizes.
// The second dimension of p/ps is 23, while the loops in this file only use indices 0..20.
// NOTE(review): with `int` widened to 64 bits, p and ps take maxn*23*8 bytes each —
// check this against the judge's memory limit.
int n, k, tin[maxn], p[maxn][23], ps[maxn][23], h[maxn], ds[maxn], dsz[maxn];
// Running counter used by dfs() to hand out entry times.
int timer = 0;
// Adjacency lists of the input tree; solve() later appends entries to the lists of
// union-find representatives.
vector<int> g[maxn];
// vec[c] collects (tin, vertex) pairs for every vertex whose label is c.
vector<pair<int,int>> vec[maxn];
// Depth-first traversal that roots the tree.
//   u   - vertex being entered
//   ant - parent of u (the root is passed as its own parent)
// Fills tin[u], the ancestor table p[u][0..20] and the depth h[] of every child.
void dfs(int u, int ant) {
    tin[u] = ++timer;
    p[u][0] = ant;
    // p[u][lvl + 1] is the 2^lvl-th ancestor of the 2^lvl-th ancestor.
    for (int lvl = 0; lvl < 20; ++lvl) {
        p[u][lvl + 1] = p[p[u][lvl]][lvl];
    }
    for (auto nxt : g[u]) {
        if (nxt != ant) {
            h[nxt] = h[u] + 1;
            dfs(nxt, u);
        }
    }
}
int lca(int u, int v) {
if(h[u] < h[v]) swap(u,v);
for(int i = 20; i >= 0; i--) {
if(h[p[u][i]] >= h[v]) {
u = p[u][i];
}
}
if(u == v) return u;
for(int i = 20; i >= 0; i--) {
if(p[u][i] != p[v][i]) {
u = p[u][i];
v = p[v][i];
}
}
return p[u][0];
}
// Union-find lookup with path compression: returns the representative of v's set
// and re-points every vertex on the walked chain directly at that representative.
int find(int v) {
    // First pass: locate the representative.
    int root = v;
    while (ds[root] != root) {
        root = ds[root];
    }
    // Second pass: compress the chain from v up to the representative.
    while (ds[v] != root) {
        const int next = ds[v];
        ds[v] = root;
        v = next;
    }
    return root;
}
// Union by size: hangs the set with the smaller dsz under the other one and
// accumulates the size. The caller in this file passes results of find().
void join(int u, int v) {
    const bool uIsSmaller = dsz[u] < dsz[v];
    const int keep = uIsSmaller ? v : u;
    const int absorb = uIsSmaller ? u : v;
    ds[absorb] = keep;
    dsz[keep] += dsz[absorb];
}
// qtdf : number of visited components that border exactly one other component.
// mark : visited flags for dfs1(), indexed by union-find representative.
int qtdf = 0, mark[maxn];
// Walks the contracted graph starting at representative u and counts its leaves.
//   u - union-find representative of the component being visited
// Every adjacency entry is mapped to its representative first. Entries that resolve
// back to u are edges inside the component and must not be counted as neighbours;
// otherwise a leaf component made of several vertices would report two "neighbours"
// (itself and its only real neighbour) and be missed.
void dfs1(int u) {
    mark[u] = 1;
    set<int> st;  // distinct neighbouring components of u
    for (auto v : g[u]) {
        v = find(v);
        if (v == u) continue;  // edge inside u's own component
        st.insert(v);
        if (mark[v]) continue;
        dfs1(v);
    }
    if (st.size() == 1) qtdf++;
}
// Marks every edge on the upward path from u to its ancestor `top` as covered.
// A jump of 2^lvl edges starting at u is recorded as one increment of ps[u][lvl];
// solve() pushes these counters down to single edges afterwards.
static void coverPathUp(int u, int top) {
    for (int lvl = 20; lvl >= 0; lvl--) {
        if (h[p[u][lvl]] > h[top]) {
            ps[u][lvl]++;
            u = p[u][lvl];
        }
    }
    // u is now either `top` itself or one level below it; in the latter case
    // one more edge remains to be marked.
    if (u != top) {
        ps[u][0]++;
    }
}

// Reads one instance from stdin and prints one integer to stdout.
// Input, as consumed here: n and k, then n-1 edges (u v), then n labels, the i-th
// belonging to vertex i.
// Steps:
//   1. Root the tree at vertex 1.
//   2. For every label, mark as covered all edges on the paths between its vertices
//      that are consecutive in entry-time order (cyclically).
//   3. Merge the endpoints of every covered edge with the union-find structure.
//   4. Contract each merged component onto its representative and let dfs1() walk
//      the contracted graph, accumulating qtdf.
// Output: 0 if no edge stayed uncovered, otherwise (qtdf + 1) / 2.
void solve() {
    cin >> n >> k;
    for (int e = 1; e <= n - 1; e++) {
        int u, v;
        cin >> u >> v;
        g[u].pb(v);
        g[v].pb(u);
    }
    dfs(1, 1);
    for (int i = 1; i <= n; i++) {
        int cl;
        cin >> cl;
        vec[cl].pb(mp(tin[i], i));
    }
    for (int c = 1; c <= k; c++) {
        // A label without vertices has nothing to connect; indexing vec[c][0]
        // on an empty vector would be undefined behaviour.
        if (vec[c].empty()) continue;
        sort(all(vec[c]));
        // Close the cycle: the last vertex is also paired with the first one.
        const pair<int,int> first = vec[c][0];
        vec[c].pb(first);
        for (size_t idx = 1; idx < vec[c].size(); idx++) {
            const int u = vec[c][idx - 1].sc;
            const int v = vec[c][idx].sc;
            const int lc = lca(u, v);
            coverPathUp(u, lc);
            coverPathUp(v, lc);
        }
    }
    // Push the jump counters down: a covered jump of 2^lvl edges from i consists of
    // a 2^(lvl-1) jump from i followed by a 2^(lvl-1) jump from p[i][lvl-1].
    for (int lvl = 20; lvl >= 1; lvl--) {
        for (int i = 1; i <= n; i++) {
            ps[i][lvl - 1] += ps[i][lvl];
            ps[p[i][lvl - 1]][lvl - 1] += ps[i][lvl];
        }
    }
    for (int i = 1; i <= n; i++) {
        ds[i] = i;
        dsz[i] = 1;
    }
    // ans counts the edges that are still uncovered.
    int ans = n - 1;
    for (int i = 1; i <= n; i++) {
        if (ps[i][0] == 0) continue;
        const int a = find(i);
        const int b = find(p[i][0]);
        if (a != b) {
            ans--;
            join(a, b);
        }
    }
    // Contract every component onto its representative. Only edges that leave the
    // component are kept, so the representative's list holds exactly its
    // neighbouring components and never refers back to its own component.
    for (int i = 1; i <= n; i++) {
        const int rep = find(i);
        if (rep == i) {
            // Filter the representative's own list in place.
            size_t kept = 0;
            for (size_t idx = 0; idx < g[i].size(); idx++) {
                if (find(g[i][idx]) != i) {
                    g[i][kept++] = g[i][idx];
                }
            }
            g[i].resize(kept);
        } else {
            // Hand the outgoing edges of a non-representative to its representative.
            for (const auto x : g[i]) {
                const int rx = find(x);
                if (rx != rep) g[rep].pb(rx);
            }
        }
    }
    dfs1(find(1));
    if (ans == 0) cout << 0 << endl;
    else cout << (qtdf + 1) / 2 << endl;
}
// Program entry point. Declared with int32_t because the `int` token is
// macro-replaced by `long long` in this file.
// Switches iostream to unsynchronised, untied mode and runs solve() once per test case
// (the test-case count is fixed to 1; reading it from input is commented out).
int32_t main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    // freopen("in.in", "r", stdin);
    // freopen("out.out", "w", stdout);
    int tt = 1;
    // cin >> tt;
    while (tt--) {
        solve();
    }
    return 0;
}
Compilation message
mergers.cpp: In function 'void solve()':
mergers.cpp:100:26: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<std::pair<long long int, long long int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
100 | for(int j = 1; j < vec[i].size(); j++) {
| ~~^~~~~~~~~~~~~~~
mergers.cpp: In function 'int32_t main()':
mergers.cpp:175:9: warning: unused variable 't' [-Wunused-variable]
175 | int t = 1;
| ^
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
12 ms |
23764 KB |
Output is correct |
2 |
Incorrect |
13 ms |
23912 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
12 ms |
23764 KB |
Output is correct |
2 |
Incorrect |
13 ms |
23912 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
12 ms |
23764 KB |
Output is correct |
2 |
Incorrect |
13 ms |
23912 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
164 ms |
71116 KB |
Output is correct |
2 |
Correct |
194 ms |
78012 KB |
Output is correct |
3 |
Incorrect |
18 ms |
25368 KB |
Output isn't correct |
4 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
12 ms |
23764 KB |
Output is correct |
2 |
Incorrect |
13 ms |
23912 KB |
Output isn't correct |
3 |
Halted |
0 ms |
0 KB |
- |