#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<int, int>;
#define pb push_back
#define ff first
#define ss second
#define ins insert
struct FT{
vector<int> bit;
int n;
FT(int ns){
n = ns;
bit.resize(n + 1);
}
void upd(int v, int k){
while (v <= n){
bit[v] += k;
v |= (v + 1);
}
}
int get(int v){
int out = 0;
while (v > 0){
out += bit[v];
v = (v & (v + 1)) - 1;
}
return out;
}
int get(int l, int r){
return get(r) - get(l - 1);
}
};
int main(){
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int n, k; cin>>n>>k;
vector<int> g[n + 1];
for (int i = 1; i < n; i++){
int x, y; cin>>x>>y;
g[x].pb(y);
g[y].pb(x);
}
vector<int> ct[k + 1];
for (int i = 1; i <= n; i++){
int x; cin>>x;
ct[x].pb(i);
}
vector<int> sz(n + 1), h(n + 1), d(n + 1), p(n + 1);
function<void(int, int)> fill = [&](int v, int pr){
p[v] = pr;
sz[v] = 1;
for (int i: g[v]){
if (i == pr) continue;
d[i] = d[v] + 1;
fill(i, v);
if (sz[i] > sz[h[v]]){
h[v] = i;
}
sz[v] += sz[i];
}
};
fill(1, 0);
vector<int> head(n + 1), pos(n + 1);
int timer = 0;
function<void(int, int)> fill_hld = [&](int v, int k){
head[v] = k;
pos[v] = ++timer;
if (!h[v]) return;
fill_hld(h[v], k);
for (int i: g[v]){
if (pos[i] || i == h[v]) continue;
fill_hld(i, i);
}
};
fill_hld(1, 1);
vector<pii> all;
auto add = [&](int x, int y){
while (head[x] != head[y]){
if (d[head[x]] > d[head[y]]) swap(x, y);
all.pb({pos[head[y]], pos[y]});
y = p[head[y]];
}
if (d[x] > d[y]) swap(x, y);
all.pb({pos[x], pos[y]});
};
auto un = [&](){
vector<pii> ret;
sort(all.begin(), all.end());
for (auto [l, r]: all){
if (ret.empty()){
ret.pb({l, r});
continue;
}
auto [l1, r1] = ret.back();
if (r1 + 1 < l){
ret.pb({l, r});
}
else {
ret.pb({l1, r});
}
}
return ret;
};
auto cp = [&](){
vector<pii> ret;
int pre = 1;
for (auto [l, r]: all){
if (pre < l){
ret.pb({pre, l - 1});
}
pre = r + 1;
}
if (pre <= n){
ret.pb({pre, n});
}
return ret;
};
vector<pii> end[n + 1];
for (int i = 1; i <= k; i++){
all.clear();
add(ct[i][0], ct[i][0]);
for (int j = 0; j + 1 < ct[i].size(); j++){
add(ct[i][j], ct[i][j + 1]);
}
all = un();
all = cp();
for (auto [l, r]: all){
end[r].pb({l, i});
}
}
vector<int> l(k + 1, n), r(k + 1), end1[n + 1];
for (int i = 1; i <= k; i++){
for (int j: ct[i]){
l[i] = min(l[i], pos[j]);
r[i] = max(r[i], pos[j]);
}
end1[r[i]].pb(l[i]);
}
vector<int> out(k + 1, k);
FT T(n);
for (int i = 1; i <= n; i++){
for (int t: end1[i]){
T.upd(t, 1);
}
for (auto [t, j]: end[i]){
out[j] -= T.get(t, n);
}
}
int ans = k;
for (int i = 1; i <= k; i++){
ans = min(ans, out[i]);
}
cout<<--ans<<"\n";
}
Compilation message
capital_city.cpp: In function 'int main()':
capital_city.cpp:134:31: warning: comparison of integer expressions of different signedness: 'int' and 'std::vector<int>::size_type' {aka 'long unsigned int'} [-Wsign-compare]
134 | for (int j = 0; j + 1 < ct[i].size(); j++){
| ~~~~~~^~~~~~~~~~~~~~
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
0 ms |
348 KB |
Output is correct |
2 |
Correct |
0 ms |
452 KB |
Output is correct |
3 |
Correct |
1 ms |
456 KB |
Output is correct |
4 |
Incorrect |
0 ms |
348 KB |
Output isn't correct |
5 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
0 ms |
348 KB |
Output is correct |
2 |
Correct |
0 ms |
452 KB |
Output is correct |
3 |
Correct |
1 ms |
456 KB |
Output is correct |
4 |
Incorrect |
0 ms |
348 KB |
Output isn't correct |
5 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Incorrect |
132 ms |
52140 KB |
Output isn't correct |
2 |
Halted |
0 ms |
0 KB |
- |
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
0 ms |
348 KB |
Output is correct |
2 |
Correct |
0 ms |
452 KB |
Output is correct |
3 |
Correct |
1 ms |
456 KB |
Output is correct |
4 |
Incorrect |
0 ms |
348 KB |
Output isn't correct |
5 |
Halted |
0 ms |
0 KB |
- |