#include <bits/stdc++.h>
#define ff first
#define ss second
#define pb push_back
#define mp make_pair
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair<ll, ll> pll;
typedef pair<ull, ull> pull;
typedef pair<int, int> pii;
typedef pair<ld, ld> pld;
int n, k;
vector<int> g[200009];
int col[200009];
int depth[200009];
int dad[200009][30];
void dfs1(int v, int prt){
for(int i = 1; i < 30; ++i)
dad[v][i] = dad[dad[v][i-1]][i-1];
for(auto u : g[v])
if(u != prt){
depth[u] = depth[v]+1;
dad[u][0] = v;
dfs1(u, v);
}
}
int lca(int x, int y){
if(depth[x] < depth[y])
swap(x, y);
for(int i = 29; i >= 0; --i)
if(depth[dad[x][i]] >= depth[y])
x = dad[x][i];
if(x == y) return x;
for(int i = 29; i >= 0; --i)
if(dad[x][i] != dad[y][i]){
x = dad[x][i];
y = dad[y][i];
}
return dad[x][0];
}
int root[200009];
int prt[200009];
int sz[200009];
set<int> con[200009];
int find(int x){
if(prt[x] == x) return x;
return prt[x] = find(prt[x]);
}
void merge(int x, int y){
x = find(x);
y = find(y);
if(x == y) return;
if(con[y].size() > con[x].size())
swap(x, y);
con[y].erase(x);
con[x].erase(y);
for(auto u : con[y])
con[x].insert(u);
root[x] = min(root[x], root[y]);
prt[y] = x;
sz[x] += sz[y];
}
void dfs2(int v, int prt){
for(auto u : g[v])
if(u != prt){
dfs2(u, v);
if(col[u] != col[v] && depth[u] != root[col[u]]){
con[col[u]].insert(col[v]);
if(con[col[v]].find(col[u]) != con[col[v]].end()){
merge(col[u], col[v]);
col[v] = find(col[v]);
}
}
}
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(nullptr);
//freopen("in.txt", "r", stdin);
//freopen("out.txt", "w", stdout);
cin >> n >> k;
for(int i = 0; i < n-1; ++i){
int x, y;
cin >> x >> y;
g[x].pb(y);
g[y].pb(x);
}
for(int i = 1; i <= n; ++i)
cin >> col[i];
depth[1] = 1;
dfs1(1, 0);
for(int i = 1; i <= n; ++i){
if(root[col[i]] == 0)
root[col[i]] = i;
else
root[col[i]] = lca(root[col[i]], i);
}
for(int i = 1; i <= k; ++i){
root[i] = depth[root[i]];
prt[i] = i;
sz[i] = 1;
}
dfs2(1, 0);
int ans = 1e9;
for(int i = 1; i <= k; ++i){
int x = find(i);
if(con[x].empty())
ans = min(ans, sz[x]);
}
assert(ans != 1e9);
cout << ans-1 << '\n';
}
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Runtime error |
34 ms |
29048 KB |
Execution killed with signal 11 (could be triggered by violating memory limits) |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Runtime error |
34 ms |
29048 KB |
Execution killed with signal 11 (could be triggered by violating memory limits) |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Runtime error |
587 ms |
135168 KB |
Execution killed with signal 11 (could be triggered by violating memory limits) |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Runtime error |
34 ms |
29048 KB |
Execution killed with signal 11 (could be triggered by violating memory limits) |
2 |
Halted |
0 ms |
0 KB |
- |