#include <bits/stdc++.h>
#define pb push_back
#define s second
#define f first
#define pb push_back
#define pii pair <int,int>
#define ll long long
#define int ll
using namespace std;
// N bounds every per-vertex / per-edge array below; inf is not referenced by the code in this file.
const int N = 5e5 + 5,inf = 1e18;
// v[x]: adjacency list of vertex x as (neighbour, edge id) pairs.
// vec[c]: (tin, vertex) pairs of the vertices whose colour is c, filled in main.
vector <pii> v[N],vec[N];
// col[x]: colour read for vertex x; tin[x]/timer: DFS entry counter assigned by dfs;
// d[j][x]: ancestor table built in main from d[0][x] (the parent stored by dfs; main
// roots the tree with dfs(1,1), so vertex 1 is its own parent); dep[x]: depth set by dfs.
int col[N],tin[N],timer,d[25][N],dep[N];
// edge[x]: id of the edge joining x to its parent; val[id]: 1 when the edge is read, set to 0 by getpath;
// n: vertex count; D, I: largest distance seen by DFS and the vertex where it was seen; p: union-find parent.
int edge[N],val[N],n,D,I,p[N];
// vis is not referenced by the code in this file; del[x] is set by dsu when x is absorbed
// into another representative, and main skips such vertices when counting.
bool vis[N],del[N];
// st[x]: set of vertex ids; main fills it with the direct neighbours of x and dsu merges the sets.
set <int> st[N];
// Roots the tree at x (whose parent is taken to be par) and fills, for every
// vertex reached from x: d[0][.] (parent), dep[.] (depth), tin[.] (entry order)
// and edge[.] (id of the edge to the parent).
//
// Implemented with an explicit stack instead of recursion: the tree can be a
// path of up to N vertices, and one call frame per vertex can exhaust the call
// stack. Vertices are entered in exactly the order the recursive version used
// (adjacency-list order, pre-order), so tin[] is unchanged.
void dfs(int x,int par){
    d[0][x] = par;
    dep[x] = dep[par] + 1;
    tin[x] = ++timer;

    // Each frame is (vertex, index of the next adjacency entry to look at).
    std::vector<std::pair<int,int>> frames;
    frames.push_back({x, 0});
    while (!frames.empty()){
        int cur = frames.back().first;
        int idx = frames.back().second;
        if (idx == static_cast<int>(v[cur].size())){
            // All neighbours handled: equivalent to returning from the recursive call.
            frames.pop_back();
            continue;
        }
        frames.back().second = idx + 1;

        int to = v[cur][idx].first;
        int id = v[cur][idx].second;
        if (to == d[0][cur]) continue;   // do not walk back to the parent

        edge[to] = id;
        d[0][to] = cur;
        dep[to] = dep[cur] + 1;
        tin[to] = ++timer;
        frames.push_back({to, 0});
    }
}
// Returns the lowest common ancestor of x and y using the ancestor table d
// (d[j][u] is read as the 2^j-th ancestor of u) and the depths in dep.
int lca(int x,int y){
    // Keep x as the deeper (or equally deep) endpoint.
    if (dep[y] > dep[x]) swap(x,y);

    // Lift x until it sits on the same depth as y.
    for (int bit = 20; bit >= 0; --bit){
        int up = d[bit][x];
        if (dep[up] >= dep[y]) x = up;
    }
    if (x == y) return x;

    // Lift both while their ancestors differ; they end just below the LCA.
    for (int bit = 20; bit >= 0; --bit){
        int ax = d[bit][x];
        int ay = d[bit][y];
        if (ax == ay) continue;
        x = ax;
        y = ay;
    }
    return d[0][x];
}
void getpath(int x,int y){
int c = lca(x,y);
while (x != c){
val[edge[x]] = 0;
x = d[0][x];
}
while (y != c){
val[edge[y]] = 0;
y = d[0][y];
}
}
// Walks the tree from x (arriving from par) with accumulated distance dd, where
// each edge contributes val[id]. Records in D the largest distance seen and in I
// the vertex where it was seen; on ties the later visited vertex wins.
void DFS(int x,int par,int dd){
    if (!(dd < D)){
        D = dd;
        I = x;
    }
    for (const auto &nb : v[x]){
        if (nb.first == par) continue;
        DFS(nb.first, x, dd + val[nb.second]);
    }
}
// Union-find lookup: returns the representative of x and re-points every
// vertex on the way directly at that representative.
int P(int x){
    // First pass: locate the representative.
    int root = x;
    while (p[root] != root) root = p[root];

    // Second pass: compress the path that was just walked.
    while (p[x] != root){
        int above = p[x];
        p[x] = root;
        x = above;
    }
    return root;
}
// Merges the components containing x and y, where x and y are the two endpoints
// of one tree edge (this is how main calls it).
//
// st[r] for a representative r holds the vertices adjacent to r's component from
// outside. After the merge the surviving representative's set is the union of
// both sets minus x and y, which are now inside the merged component. The
// absorbed representative gets del[] = 1 and an empty set.
//
// Two fixes over the previous version:
//  * it erased the representatives px/py via erase(find(...)); the sets contain
//    neighbour vertex ids, so once a component had more than one vertex find()
//    could return end(), and erasing end() is undefined behaviour. The edge
//    endpoints are erased instead, by key, which is a no-op when absent.
//  * it copied the larger set on every call; elements of the smaller set are now
//    inserted into the larger one in place (small-to-large).
void dsu(int x,int y){
    int px = P(x),py = P(y);
    if (px == py) return;                       // already merged, nothing to do
    if (st[px].size() < st[py].size()) swap(px,py);

    p[py] = px;
    del[py] = 1;

    for (int w : st[py])
        st[px].insert(w);
    st[py].clear();

    // x and y are no longer outside neighbours of the merged component.
    st[px].erase(x);
    st[px].erase(y);
}
// Reads n, k, the n-1 tree edges and the n vertex colours from stdin.
// For every colour, zeroes the edges on the paths between vertices of that colour
// that are consecutive in DFS entry order, then contracts all zeroed edges with
// dsu. Prints (L + 1) / 2 where L is the number of surviving components whose
// neighbour set has exactly one element.
signed main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int k;
    cin >> n >> k;
    for (int i = 1; i < n; i++){
        int a,b;
        cin >> a >> b;
        st[a].insert(b);
        st[b].insert(a);
        v[a].push_back({b,i});
        v[b].push_back({a,i});
        val[i] = 1;                 // every edge starts as "kept"
    }

    // Root at vertex 1 (its own parent), then fill the ancestor table.
    dfs(1,1);
    for (int j = 1; j <= 20; j++)
        for (int i = 1; i <= n; i++)
            d[j][i] = d[j - 1][d[j - 1][i]];

    for (int i = 1; i <= n; i++){
        cin >> col[i];
        vec[col[i]].push_back({tin[i],i});
    }

    for (int c = 1; c <= k; c++){
        sort(vec[c].begin(),vec[c].end());
        // size_t index: avoids the signed/unsigned comparison against size().
        for (size_t j = 1; j < vec[c].size(); j++)
            getpath(vec[c][j].second,vec[c][j - 1].second);
    }

    for (int i = 1; i <= n; i++)
        p[i] = i;
    // Contract every edge that getpath zeroed.
    for (int i = 2; i <= n; i++)
        if (!val[edge[i]]) dsu(i,d[0][i]);

    int ans = 0;
    for (int i = 1; i <= n; i++){
        if (del[i]) continue;       // absorbed into another representative
        if (st[i].size() == 1) ans++;
    }
    cout << (ans + 1) / 2;
}
Compilation message
mergers.cpp: In function 'int main()':
mergers.cpp:113:21: warning: comparison of integer expressions of different signedness: 'long long int' and 'std::vector<std::pair<long long int, long long int> >::size_type' {aka 'long unsigned int'} [-Wsign-compare]
113 | for (int j = 1; j < vec[i].size(); j++){
| ~~^~~~~~~~~~~~~~~
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
15 ms |
49496 KB |
Output is correct |
2 |
Runtime error |
52 ms |
100180 KB |
Execution killed with signal 6 |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
15 ms |
49496 KB |
Output is correct |
2 |
Runtime error |
52 ms |
100180 KB |
Execution killed with signal 6 |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
15 ms |
49496 KB |
Output is correct |
2 |
Runtime error |
52 ms |
100180 KB |
Execution killed with signal 6 |
3 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Execution timed out |
3060 ms |
91068 KB |
Time limit exceeded |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
15 ms |
49496 KB |
Output is correct |
2 |
Runtime error |
52 ms |
100180 KB |
Execution killed with signal 6 |
3 |
Halted |
0 ms |
0 KB |
- |