#include<bits/stdc++.h>
using namespace std;

// Convenience aliases (not used by the current solution).
using ll = long long;
using pii = pair<int, int>;

constexpr int mxn = 500000 + 5;   // capacity of every per-node / per-label array
constexpr int inf = 1000000000;

vector<int> v;             // path buffer filled by dfs()
vector<int> adj[mxn];      // adjacency lists of the tree (1-indexed nodes)
int S[mxn];                // S[i] = label (state id, 1..k) of node i
vector<int> con[mxn];      // con[s] = all nodes whose label is s
int level[mxn];            // depth assigned by DFS(); level[0] stays 0 as a sentinel
int dp[mxn][20];           // dp[x][i] = 2^i-th ancestor of x (0 above the root)
// Disjoint-set union over indices [0, mxn) with path compression.
// unite() always attaches the root with the larger level[] below the root with
// the smaller-or-equal level[], so the root of a set is an element of minimum
// level[] in that set. For dsu1 (elements are tree nodes) process() depends on
// this: it climbs from a root through dp[root][0]. For dsu2 the elements are
// labels, so level[label] is only an arbitrary (but harmless) tie-breaker.
struct DSU {
    int parent[mxn];

    // Makes every index its own singleton set.
    void init() {
        for(int i = 0; i < mxn; i++) parent[i] = i;
    }

    // Returns the root of u's set and compresses the path to it.
    // Iterative on purpose: DFS() unites same-label children bottom-up, which
    // can build an uncompressed chain of length ~n; a recursive find would then
    // recurse n levels deep and can overflow the call stack.
    int findrep(int u) {
        int root = u;
        while(parent[root] != root) root = parent[root];
        while(parent[u] != root) {
            const int next = parent[u];
            parent[u] = root;
            u = next;
        }
        return root;
    }

    // Merges the sets of u and v; the root with the smaller level[] survives.
    void unite(int u,int v) {
        u = findrep(u);
        v = findrep(v);
        if(u == v) return;
        if(level[u] < level[v]) swap(u,v);
        parent[u] = v;
    }
} dsu1,dsu2;
bool dfs(int cur,int prev,int src) {
if(cur == src) return true;
bool f = false;
for(int u : adj[cur]) {
if(u == prev) continue;
f |= dfs(u,cur,src);
}
if(f) v.push_back(cur);
return f;
}
void DFS(int cur,int prev) {
level[cur] = level[prev] + 1;
dp[cur][0] = prev;
for(int i = 1; i < 20; i++) dp[cur][i] = dp[dp[cur][i - 1]][i - 1];
for(int u : adj[cur]) {
if(u != prev) {
DFS(u,cur);
if(S[u] == S[cur]) dsu1.unite(u,cur);
}
}
}
// Lowest common ancestor of u and v using the binary-lifting table dp[][].
// level[] and dp[][] must already have been filled by DFS().
int findLCA(int u,int v){
    if(level[v] > level[u]) swap(u, v);   // u is now the deeper node
    const int gap = level[u] - level[v];
    for(int bit = 0; bit < 20; ++bit) {
        if((gap >> bit) & 1) u = dp[u][bit];
    }
    if(u == v) return u;                  // v was an ancestor of u
    // Lift both nodes as high as possible while keeping them distinct.
    for(int bit = 19; bit >= 0; --bit) {
        const int upU = dp[u][bit];
        const int upV = dp[v][bit];
        if(upU == upV) continue;
        u = upU;
        v = upV;
    }
    return dp[u][0];
}
// Starting from the dsu1 component of u, repeatedly merges the current
// component with the component of its root's parent (dp[root][0]). This stops
// as soon as the current root's level is <= lv. Every merge is mirrored in
// dsu2 by uniting the labels of the two roots. Relies on dsu1 roots being
// minimum-level nodes of their components.
void process(int u,int lv) {
    int top = dsu1.findrep(u);
    while(level[top] > lv) {
        const int above = dsu1.findrep(dp[top][0]);
        dsu1.unite(top, above);
        dsu2.unite(S[top], S[above]);
        top = above;
    }
}
int main() {
int n,k;
cin >> n >> k;
for(int i = 1; i < n; i++) {
int u,v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
for(int i = 1; i <= n; i++) {
cin >> S[i];
con[S[i]].push_back(i);
}
dsu1.init();
dsu2.init();
for(int i = 1; i <= k; i++) {
sort(con[i].begin(),con[i].end(),[](int a,int b) {
return level[a] < level[b];
});
for(int j = 1; j < con[i].size(); j++) {
v.clear();
int u = con[i][j];
int v = con[i][j - 1];
int lca = findLCA(u,v);
process(u,level[lca]);
process(v,level[lca]);
}
}
int deg[k + 1] = {};
for(int i = 1; i <= n; i++) {
int x = dsu2.findrep(S[i]);
for(int u : adj[i]) {
if(x == dsu2.findrep(S[u])) continue;
deg[dsu2.findrep(S[u])]++;
}
}
int leaf = 0;
for(int i = 1; i <= k; i++) {
if(dsu2.parent[i] == i && deg[i] == 1) leaf++;
}
cout<<(leaf + 1) / 2<<endl;
}
Compilation message
mergers.cpp: In function 'int main()':
mergers.cpp:112:26: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
112 | for(int j = 1; j < con[i].size(); j++) {
| ~~^~~~~~~~~~~~~~~
| # | Result | Execution time | Memory | Grader output |
|---|--------|----------------|--------|---------------|
| 1 | Correct | 18 ms | 27756 KB | Output is correct |
| 2 | Incorrect | 18 ms | 27756 KB | Output isn't correct |
| 3 | Halted | 0 ms | 0 KB | - |
| # | Result | Execution time | Memory | Grader output |
|---|--------|----------------|--------|---------------|
| 1 | Correct | 18 ms | 27756 KB | Output is correct |
| 2 | Incorrect | 18 ms | 27756 KB | Output isn't correct |
| 3 | Halted | 0 ms | 0 KB | - |
| # | Result | Execution time | Memory | Grader output |
|---|--------|----------------|--------|---------------|
| 1 | Correct | 18 ms | 27756 KB | Output is correct |
| 2 | Incorrect | 18 ms | 27756 KB | Output isn't correct |
| 3 | Halted | 0 ms | 0 KB | - |
| # | Result | Execution time | Memory | Grader output |
|---|--------|----------------|--------|---------------|
| 1 | Correct | 130 ms | 33676 KB | Output is correct |
| 2 | Correct | 157 ms | 37092 KB | Output is correct |
| 3 | Incorrect | 21 ms | 27884 KB | Output isn't correct |
| 4 | Halted | 0 ms | 0 KB | - |
| # | Result | Execution time | Memory | Grader output |
|---|--------|----------------|--------|---------------|
| 1 | Correct | 18 ms | 27756 KB | Output is correct |
| 2 | Incorrect | 18 ms | 27756 KB | Output isn't correct |
| 3 | Halted | 0 ms | 0 KB | - |