#include <bits/stdc++.h>
using namespace std;
// Problem-size limits: MX bounds vertex indices (1..N) and color indices (1..K);
// LOG is the number of binary-lifting levels (jumps of 2^0 .. 2^(LOG-1) edges).
const int MX=2e5+5,LOG=20;
// Summing every lifting level must be able to climb from any vertex to the root.
static_assert((1 << LOG) > MX, "LOG lifting levels cannot cover MX vertices");
int N,K;
// A[i], B[i]: endpoints of the i-th tree edge; S[v]: color of vertex v.
int A[MX],B[MX],S[MX];
// adj: tree adjacency lists; group[c]: vertices whose color is c (input order).
vector<int> adj[MX], group[MX];
// id[v][k]: implication-graph node standing for the chain of 2^k vertices that
// starts at v and goes upward (shorter near the root).
// up[v][k]: 2^k-th ancestor of v, or 0 once the jump passes the root.
int id[MX][LOG], up[MX][LOG];
// Euler-tour entry times (tin) and largest entry time inside the subtree (tout),
// filled by dfs() and consumed by isAnc().
int timer=0,tin[MX],tout[MX];
void dfs(int v, int p) {
tin[v]=++timer;
up[v][0]=p;
for(int k=1;k<20;k++) up[v][k]=up[up[v][k-1]][k-1];
for(auto u:adj[v]) {
if(u==p) continue;
dfs(u,v);
}
tout[v]=timer;
}
// True iff `anc` is an ancestor of `v` or `v` itself: v's Euler-tour interval
// [tin[v], tout[v]] has to sit inside anc's interval.
bool isAnc(int anc, int v) {
    if (tin[v] < tin[anc]) return false;
    return tout[anc] >= tout[v];
}
// Lowest common ancestor of u and v by binary lifting. If neither is an ancestor
// of the other, lift u as high as possible while it stays strictly below every
// common ancestor; its parent is then the LCA.
int LCA(int u, int v) {
    if (isAnc(u, v)) return u;
    if (isAnc(v, u)) return v;
    for (int k = 19; k >= 0; --k) {
        const int jump = up[u][k];
        // Skip jumps that leave the tree or already land on a common ancestor.
        if (jump == 0 || isAnc(jump, v)) continue;
        u = jump;
    }
    return up[u][0];
}
// Implication graph and its transpose (for Kosaraju). Nodes 1..K are colors;
// main() then assigns 20 lifting nodes per vertex, so ids run up to K + 20*N.
// MX*LOG entries are NOT enough (K + 20*N exceeds MX*LOG - 1 once N = 2e5 and
// K >= 100), so one extra MX-sized slab is reserved for the color nodes.
// NOTE(review): one vector per node is memory-heavy at these sizes; a flat
// (CSR) edge array would be considerably leaner.
vector<int> scc[MX*(LOG+1)], rev[MX*(LOG+1)];
// Nodes in increasing DFS finish time, filled by dfs0().
vector<int> ord;
// Visited flags shared by both Kosaraju passes (reset in between by main).
bool vis[MX*(LOG+1)];
// First Kosaraju pass: DFS over `scc` from v, appending every newly finished node
// to `ord` in post-order. Iterative because a single DFS path in this graph can
// span a large fraction of its millions of nodes, far beyond a safe call depth.
void dfs0(int v) {
    // (node, index of the next out-edge to examine)
    static vector<pair<int, int>> stk;
    stk.clear();
    vis[v] = true;
    stk.emplace_back(v, 0);
    while (!stk.empty()) {
        const int x = stk.back().first;
        const int i = stk.back().second;
        if (i == (int)scc[x].size()) {
            ord.push_back(x);  // all successors finished
            stk.pop_back();
            continue;
        }
        stk.back().second = i + 1;
        const int u = scc[x][i];
        if (!vis[u]) {
            vis[u] = true;
            stk.emplace_back(u, 0);
        }
    }
}
// Nodes of the component currently being collected by dfs1().
vector<int> comp;
// Per-node arrays, indexed by graph node id (up to K + 20*N, which exceeds
// MX*LOG - 1 at the limits, hence MX*(LOG+1)):
//   head[v] - representative (first discovered node) of v's SCC,
//   cnt[h]  - number of color nodes (ids 1..K) in the SCC represented by h,
//   deg[h]  - number of edges leaving the SCC represented by h.
int cnt[MX*(LOG+1)], head[MX*(LOG+1)], deg[MX*(LOG+1)];
// Second Kosaraju pass: DFS over the transposed graph `rev` from v, appending each
// newly visited node to `comp` in pre-order (so comp.front() == v). Iterative for
// the same call-depth reason as dfs0().
void dfs1(int v) {
    // (node, index of the next transposed edge to examine)
    static vector<pair<int, int>> stk;
    stk.clear();
    vis[v] = true;
    comp.push_back(v);
    stk.emplace_back(v, 0);
    while (!stk.empty()) {
        const int x = stk.back().first;
        const int i = stk.back().second;
        if (i == (int)rev[x].size()) {
            stk.pop_back();
            continue;
        }
        stk.back().second = i + 1;
        const int u = rev[x][i];
        if (!vis[u]) {
            vis[u] = true;
            comp.push_back(u);
            stk.emplace_back(u, 0);
        }
    }
}
// Adds edges base -> lifting nodes that together cover exactly the vertices on
// the upward path from v to, but not including, its ancestor `anc`.
// Precondition (as called from main): v lies strictly below anc.
void add(int v, int anc, int base) {
    for (int k = 19; k >= 0; k--) {
        // Jump 2^k only if the landing vertex is still strictly below anc. The
        // vertices skipped (v and its next 2^k - 1 ancestors) are precisely the
        // chain represented by id[v][k]. Using id[v][k+1] here, as before, covered
        // twice as many vertices (running past anc) and read id[v][20] out of
        // bounds when k == 19.
        if (up[v][k] != 0 && !isAnc(up[v][k], anc)) {
            assert(id[v][k] != 0);
            scc[base].push_back(id[v][k]);
            v = up[v][k];
        }
    }
    // v is now the child of anc on the path; cover it on its own.
    scc[base].push_back(id[v][0]);
}
// Reads N, K, the N-1 tree edges and each vertex's color S[i]; fills adj and
// group (group[c] lists the vertices of color c in input order).
static void readInput() {
    cin >> N >> K;
    for (int i = 0; i < N - 1; i++) {
        cin >> A[i] >> B[i];
        adj[A[i]].push_back(B[i]);
        adj[B[i]].push_back(A[i]);
    }
    for (int i = 1; i <= N; i++) {
        cin >> S[i];
        group[S[i]].push_back(i);
    }
}

// Assigns node ids K+1.. to the 20 lifting nodes of every vertex, roots the tree
// at vertex 1, and wires the lifting nodes:
//   id[i][j] -> id[i][j-1] and id[up[i][j-1]][j-1]   (lower and upper half)
//   id[i][0] -> S[i]                                  (single vertex -> its color)
// Returns the largest node id used.
static int buildLiftingGraph() {
    int z = K;
    for (int i = 1; i <= N; i++) {
        for (int j = 0; j < 20; j++) {
            id[i][j] = ++z;
        }
    }
    dfs(1, 0);
    for (int i = 1; i <= N; i++) {
        for (int j = 1; j < 20; j++) {
            scc[id[i][j]].push_back(id[i][j - 1]);
            if (up[i][j - 1] != 0) scc[id[i][j]].push_back(id[up[i][j - 1]][j - 1]);
        }
        scc[id[i][0]].push_back(S[i]);
    }
    return z;
}

// For each color i: find the LCA of its vertices, then add edges from i to the
// colors of every vertex on the paths from those vertices up to the LCA (the
// vertices themselves already have color i; the LCA's color is linked directly).
static void addColorConstraints() {
    for (int i = 1; i <= K; i++) {
        if (group[i].empty()) continue;  // group[i][0] below would be out of range
        int cur = group[i][0];
        for (int x : group[i]) cur = LCA(cur, x);
        for (int x : group[i]) {
            // Only the part strictly between x and cur: from x's parent to below cur.
            if (up[x][0] != 0 && !isAnc(up[x][0], cur)) add(up[x][0], cur, i);
        }
        if (S[cur] != i) scc[i].push_back(S[cur]);
    }
}

// Kosaraju over nodes 1..z: builds the transposed graph, orders nodes by finish
// time, then labels every SCC (head[] = its first discovered node) and counts the
// color nodes (ids 1..K) it contains in cnt[head].
static void computeSCCs(int z) {
    for (int i = 1; i <= z; i++) {
        for (int j : scc[i]) rev[j].push_back(i);
    }
    for (int i = 1; i <= z; i++) {
        if (!vis[i]) dfs0(i);
    }
    reverse(ord.begin(), ord.end());
    memset(vis, false, sizeof vis);
    for (int x : ord) {
        if (vis[x]) continue;
        dfs1(x);
        for (int v : comp) {
            head[v] = comp.front();
            cnt[comp.front()] += v <= K;
        }
        comp.clear();
    }
}

// Counts edges leaving each SCC, then returns (smallest color count among SCCs
// that contain a used color and have no outgoing edge) - 1.
// Presumably this is the minimum number of color merges the task asks for.
static int minimumMerges(int z) {
    for (int i = 1; i <= z; i++) {
        for (int j : scc[i]) {
            if (head[i] != head[j]) deg[head[i]] += 1;
        }
    }
    int ans = K;
    for (int i = 1; i <= K; i++) {
        if (group[i].empty()) continue;  // unused color: not a candidate
        if (deg[head[i]] == 0) ans = min(ans, cnt[head[i]]);
    }
    return ans - 1;
}

int main() {
    cin.tie(0); ios_base::sync_with_stdio(0);
    readInput();
    const int z = buildLiftingGraph();
    addColorConstraints();
    computeSCCs(z);
    cout << minimumMerges(z) << '\n';
}
/* Judge result table pasted from the submission page (Korean headers:
   결과 = verdict, 실행 시간 = run time, 메모리 = memory):
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
106 ms |
203860 KB |
Output is correct |
2 |
Correct |
55 ms |
203852 KB |
Output is correct |
3 |
Correct |
56 ms |
203860 KB |
Output is correct |
4 |
Incorrect |
55 ms |
203856 KB |
Output isn't correct |
5 |
Halted |
0 ms |
0 KB |
- |
*/
/* Judge result table (pasted submission log):
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
106 ms |
203860 KB |
Output is correct |
2 |
Correct |
55 ms |
203852 KB |
Output is correct |
3 |
Correct |
56 ms |
203860 KB |
Output is correct |
4 |
Incorrect |
55 ms |
203856 KB |
Output isn't correct |
5 |
Halted |
0 ms |
0 KB |
- |
*/
/* Judge result table (pasted submission log):
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Execution timed out |
3010 ms |
524288 KB |
Time limit exceeded |
2 |
Halted |
0 ms |
0 KB |
- |
*/
/* Judge result table (pasted submission log):
# |
결과 |
실행 시간 |
메모리 |
Grader output |
1 |
Correct |
106 ms |
203860 KB |
Output is correct |
2 |
Correct |
55 ms |
203852 KB |
Output is correct |
3 |
Correct |
56 ms |
203860 KB |
Output is correct |
4 |
Incorrect |
55 ms |
203856 KB |
Output isn't correct |
5 |
Halted |
0 ms |
0 KB |
- |
*/