This submission was migrated from the previous version of oj.uz, which used a different machine for grading. It may receive a different result if resubmitted.
#include<bits/stdc++.h>
using namespace std;
// Contest shorthand macros. Of these, only `all` and `pb` are referenced in
// the code shown here; ll, ld, sz, X, Y (and the `ii` typedef below) are unused.
// Note: `ld` expands to plain double, not long double.
#define ll long long
#define ld double
#define sz(x) (int)x.size()
#define all(x) x.begin(),x.end()
#define pb emplace_back
#define X first
#define Y second
// Size bound for every global array (presumably max n plus slack — confirm
// against the task limits).
const int N = 2e5 + 5;
typedef pair<int,int> ii;
// Disjoint-set union over towns: p[x] is x's parent (x is a root iff p[x] == x).
int p[N];
// s[r]: number of towns in the component rooted at r (used for union by size).
int s[N];
// Resets the DSU over towns 1..n so that every town is the root of its own
// component of size 1. Always returns 1.
int init(int n) {
    for (int town = 1; town <= n; ++town) {
        p[town] = town;
        s[town] = 1;
    }
    return 1;
}
int lead(int x) {
return p[x] == x ? x : p[x] = lead(p[x]);
}
// Unites the DSU components containing x and y. The smaller component is hung
// under the larger one; on equal sizes x's root stays on top.
// Returns 1 if a merge happened, 0 if x and y were already together.
int join(int x,int y) {
    const int rx = lead(x);
    const int ry = lead(y);
    if (rx == ry)
        return 0;
    const bool yBigger = s[rx] < s[ry];
    const int keep = yBigger ? ry : rx;
    const int drop = yBigger ? rx : ry;
    p[drop] = keep;
    s[keep] += s[drop];
    return 1;
}
// c[i]: colour (city index, 1..k) of town i, read in main().
int c[N];
// a[i], b[i]: endpoints of input edge i (1 <= i < n).
int a[N];
int b[N];
// S[col]: DSU roots of the same-colour components of colour col
// (main() later sorts each list by pos[]).
vector<int> S[N];
// g[r]: adjacency of the contracted tree, indexed by DSU root.
vector<int> g[N];
// nCh[u]: subtree size of u in the contracted tree (filled by dfs);
// pos[u]: 1-based preorder position assigned by hld.
int nCh[N], pos[N];
// led[u]: head of u's heavy chain; par[u]: parent of u (0 for the root).
int led[N], par[N];
// arr[k]: node at preorder position k; tot: number of positions assigned.
int arr[N], tot = 0;
// Roots the contracted tree at u. Fills par[] (with par[u] = p) and nCh[]
// (subtree sizes) for every node reachable from u. nCh[p] itself is not
// updated.
// Uses an explicit stack instead of recursion: the contracted tree can be a
// path with up to N-1 nodes, and recursing that deep may overflow the
// default call stack.
void dfs(int u,int p) {
    std::vector<int> order;       // nodes in the order they are first reached
    std::vector<int> stk(1, u);
    par[u] = p;
    while (!stk.empty()) {
        const int x = stk.back();
        stk.pop_back();
        order.push_back(x);
        nCh[x] = 1;
        for (int v : g[x])
            if (v != par[x]) {
                par[v] = x;
                stk.push_back(v);
            }
    }
    // Each node appears after its parent in `order`. A reverse sweep therefore
    // sees every subtree complete before folding it into its parent. Index 0
    // is u itself; its parent `p` is outside the traversal and is left alone
    // (this keeps nCh[0] == 0, which hld() relies on as a sentinel).
    for (int i = static_cast<int>(order.size()) - 1; i > 0; --i)
        nCh[par[order[i]]] += nCh[order[i]];
}
// Heavy-light decomposition of the contracted tree (dfs() must run first).
// Assigns preorder positions pos[] (and the inverse map arr[]) so that each
// node's largest child is visited immediately after it. Records each node's
// chain head in led[]. A nonzero `ok` means u starts a new chain.
// Implemented with an explicit stack because the tree can be a path of up to
// N-1 nodes. Pushing light children in reverse and the heavy child last makes
// the pop order identical to the recursive version: heavy subtree first, then
// light children in adjacency order.
void hld(int u,int ok) {
    std::vector<std::pair<int,int>> stk;   // (node, starts-new-chain flag)
    stk.emplace_back(u, ok);
    while (!stk.empty()) {
        const int x = stk.back().first;
        const int head = stk.back().second;
        stk.pop_back();
        led[x] = head ? x : led[par[x]];
        pos[x] = ++tot;
        arr[tot] = x;
        // nCh[0] is never written by dfs(), so it stays 0 and any child beats it;
        // ties keep the first child in adjacency order.
        int heavy = 0;
        for (int v : g[x])
            if (v != par[x] && nCh[heavy] < nCh[v])
                heavy = v;
        for (auto it = g[x].rbegin(); it != g[x].rend(); ++it)
            if (*it != par[x] && *it != heavy)
                stk.emplace_back(*it, 1);
        if (heavy)
            stk.emplace_back(heavy, 0);
    }
}
// True iff u is an ancestor of v (including u == v). In preorder, u's subtree
// occupies the position range [pos[u], pos[u] + nCh[u]); v's range must lie
// inside it.
bool is_anc(int u,int v) {
    const int uLo = pos[u];
    const int uHi = pos[u] + nCh[u];
    const int vLo = pos[v];
    const int vHi = pos[v] + nCh[v];
    return uLo <= vLo && vHi <= uHi;
}
// Lowest common ancestor of u and v, using the heavy chains built by hld().
int lca(int u,int v) {
// Lift u a whole chain at a time (to the parent of its chain head) until
// that chain head is an ancestor of v.
while (!is_anc(led[u],v)) u = par[led[u]];
// Same for v with respect to the (lifted) u.
while (!is_anc(led[v],u)) v = par[led[v]];
// The node with the smaller preorder position is the answer; the assert
// double-checks that it is an ancestor of the other.
if (pos[u] > pos[v])
swap(u,v);
assert(is_anc(u,v));
return u;
}
// Iterative min segment tree over preorder positions 1..tot. The leaf for
// position i lives at index i + tot - 1 (see upd/get in main()).
int tr[N << 1];
// ok[col]: colour col passed the segment-tree filter in main() and gets a BFS.
int ok[N];
// Not referenced in the code shown here.
vector<int> vec[N];
// have[r]: component r has already been enqueued in the current BFS.
bool have[N];
// need[col]: colour col is already part of the current BFS closure.
bool need[N];
// Reads n towns, k colours, n-1 edges and one colour per town. Contracts
// same-coloured edges into components and builds an HLD over the resulting
// tree. It then prints the minimum, over colours i that pass a filter, of
// `cur`: the number of colours other than i pulled into i's BFS closure
// (1e9 if that BFS is abandoned).
// NOTE(review): this matches the shape of JOI "Capital City", where that count
// is the number of merges — confirm against the task statement.
int main() {
ios_base::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
int n; cin >> n;
int k; cin >> k;
// Edge i (1 <= i < n) joins towns a[i] and b[i].
for(int i = 1 ; i < n ; ++i) {
cin >> a[i];
cin >> b[i];
}
for(int i = 1 ; i <= n ; ++i)
cin >> c[i];
init(n);
// Contract every edge whose endpoints share a colour.
for(int i = 1 ; i < n ; ++i)
if (c[a[i]] == c[b[i]])
join(a[i],b[i]);
// Remaining edges connect different-coloured components; they form the
// contracted tree g over DSU roots.
for(int i = 1 ; i < n ; ++i)
if (c[a[i]] != c[b[i]]) {
int x = lead(a[i]);
int y = lead(b[i]);
g[x].pb(y);
g[y].pb(x);
}
// Group component roots by colour.
for(int i = 1 ; i <= n ; ++i)
if (i == p[i])
S[c[i]].pb(i);
// Root the contracted tree at town 1's component and decompose it.
dfs(lead(1),0);
hld(lead(1),1);
// upd(p, v): set the leaf at position p to v and recompute the minima up to the root.
auto upd = [&](int p,int v) {
for(tr[p += tot - 1] = v ; p > 1 ; p >>= 1)
tr[p >> 1] = min(tr[p],tr[p ^ 1]);
};
// get(l, r): minimum over positions l..r inclusive (1e9 if nothing smaller).
auto get = [&](int l,int r) {
l += tot - 1;
r += tot;
int ans = 1e9;
for(; l < r ; l >>= 1, r >>= 1) {
if (l & 1) ans = min(ans,tr[l++]);
if (r & 1) ans = min(ans,tr[--r]);
}
return ans;
};
// Every position starts "unassigned" (1e9).
for(int i = 1 ; i <= tot ; ++i)
upd(i,1e9);
// Sort each colour's components by preorder position. The LCA of the first
// and last entries is then the LCA of the whole colour.
for(int i = 1 ; i <= k ; ++i)
sort(all(S[i]),[&](int x,int y) {
return pos[x] < pos[y];
});
// Process colours in preorder of their LCA (topmost first).
// NOTE(review): assumes every colour 1..k has at least one town; an empty
// S[x] would make S[x][0] / S[x].back() undefined behaviour — confirm the
// input guarantees this.
vector<int> perms(k);
iota(all(perms),1);
sort(all(perms),[&](int x,int y) {
int a = lca(S[x][0],S[x].back());
int b = lca(S[y][0],S[y].back());
return pos[a] < pos[b];
});
// Filter pass. P starts at pos[R] and is lowered by any value stored on
// the tree paths from colour i's components up to R. Those values come from
// colours processed earlier; unassigned leaves hold 1e9. P is then stored on
// all of i's components. ok[i] requires P == pos[S[i][0]]; since
// P <= pos[R] <= pos[S[i][0]], that means S[i][0] is R itself and no earlier
// value on those paths was smaller.
for(int i : perms) {
int R = lca(S[i][0],S[i].back());
int P = pos[R];
for(int x : S[i]) {
// Climb chain by chain until x's chain head is an ancestor of R.
while (!is_anc(led[x],R)) {
P = min(get(pos[led[x]],pos[x]),P);
x = par[led[x]];
}
// x and R are now on the same chain; take the remaining segment.
P = min(get(pos[R],pos[x]),P);
}
for(int x : S[i])
upd(pos[x],P);
if (P == pos[S[i][0]])
ok[i] = 1;
}
int ans = 1e9;
// For each surviving colour, BFS its closure: from every enqueued component,
// step to its parent (except at S[i][0]); a newly reached colour pulls in all
// of its components. The search is abandoned (cur = 1e9) as soon as a visited
// component's stored value differs from pos[S[i][0]].
for(int i = 1 ; i <= k ; ++i) if (ok[i]) {
need[i] = 1;
queue<int> qu;
// st records every enqueued component so have[]/need[] can be reset afterwards.
stack<int> st;
for(int x : S[i])
qu.push(x),
have[x] = 1;
int cur = 0;
while (qu.size()) {
int u = qu.front();
qu.pop();
st.push(u);
if (get(pos[u],pos[u]) != pos[S[i][0]]) {
cur = 1e9;
// Drain the queue into st so the cleanup below still sees every marked node.
while (qu.size()) {
st.push(qu.front());
qu.pop();
}
break;
}
if (u != S[i][0])
u = par[u];
if(!have[u]) {
have[u] = 1;
qu.push(u);
if(!need[c[u]]) {
need[c[u]] = 1; ++cur;
for(int x : S[c[u]])
if(!have[x]) {
have[x] = 1;
qu.push(x);
}
}
}
}
// Undo this BFS's marks (c[u] of a root is its component's colour).
while (st.size()) {
int u = st.top(); st.pop();
have[u] = 0;
need[c[u]] = 0;
}
if (ans > cur)
ans = cur;
}
cout << ans << endl;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |