#include <bits/stdc++.h>
using namespace std;
// Contest shorthand macros.
// Macro arguments are parenthesised so that argument expressions such as
// len(*p) or all(v[i]) expand correctly.
#define ll long long
// "\n" instead of std::endl: avoids flushing the stream on every line.
#define endl "\n"
#define INF 1000000000
#define LINF 10000000000000000LL
#define pb push_back
#define all(x) (x).begin(), (x).end()
#define len(s) ((int)(s).size())
#define test_case { int t; cin>>t; while(t--)solve(); }
#define single_case solve();
// NOTE(review): the macros `line` and `ios` hijack those identifiers for the
// rest of the file (e.g. `ios::...` would no longer compile).
#define line cerr<<"----------"<<endl;
#define ios { ios_base::sync_with_stdio(false); cin.tie(NULL); cout.tie(NULL); cerr.tie(NULL); }
#define mod 1000000007LL
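/*
Test input (format read by main: "n k", then n-1 edges "a b", then the
group label of every vertex 1..n):
5 4
1 2
2 3
2 4
1 5
1 2 3 4 1
*/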
// Upper bound used to size the per-vertex / per-label arrays below.
const int N = 1000000;

int n, k;      // header values read by main
int timer;     // Euler-tour counter shared by entry and exit times (dfstime)
int tin[N];    // entry time of each vertex
int tout[N];   // exit time of each vertex
int cnt[N];    // cnt[label] = number of vertices carrying that label
int b[N];      // group label of each vertex
int c[N];      // per-vertex flag, set to 1 by dfspotpuno
int sz[N];     // subtree size of each vertex (dfstime)

std::vector<int> g[N];   // adjacency lists of the tree
std::vector<int> w[N];   // w[label] = entry times of vertices with that label, in DFS order
std::vector<int> grupa;  // leaf counts pushed by dfs
std::set<int> s;         // global scratch set

// Binary-lifting ancestor table.
// NOTE(review): only N/2+2 rows while the other per-vertex arrays have N;
// this assumes n <= N/2+1 — confirm against the problem constraints.
int up[N / 2 + 2][20];
int lcagrupe[N];          // lcagrupe[label] = LCA of all vertices with that label
std::set<int> lcaa[N];    // lcaa[v] = labels whose group LCA is v
// True when u is an ancestor of v, counting a vertex as its own ancestor:
// v's Euler interval [tin[v], tout[v]] must lie inside u's.
bool is_ancestor(int u, int v)
{
    const bool entered_after_u = tin[v] >= tin[u];
    const bool left_before_u = tout[v] <= tout[u];
    return entered_after_u && left_before_u;
}
// Lowest common ancestor of u and v via the binary-lifting table `up`
// (up[x][j] = 2^j-th ancestor, the root mapping to itself).
int lca(int u, int v)
{
    if (is_ancestor(u, v)) return u;
    if (is_ancestor(v, u)) return v;
    // Jump from u in decreasing powers of two, never landing on an ancestor
    // of v. u ends as the highest ancestor of u that is not above v, so its
    // parent is the LCA.
    int level = 20;
    while (level-- > 0) {
        const int candidate = up[u][level];
        if (!is_ancestor(candidate, v)) u = candidate;
    }
    return up[u][0];
}
// Recursive DFS from u (parent pret) that fills:
//   up[u][0]  - parent of u (vertex 1 is made its own parent),
//   tin/tout  - entry/exit times drawn from the shared counter `timer`,
//   w[b[u]]   - appends u's entry time to its label's list,
//   sz[u]     - number of vertices in u's subtree.
void dfstime(int u, int pret)
{
    up[u][0] = (u == 1) ? u : pret;
    tin[u] = ++timer;
    w[b[u]].push_back(timer);
    for (const int child : g[u]) {
        if (child != pret) {
            dfstime(child, u);
            sz[u] += sz[child];
        }
    }
    sz[u] += 1;
    tout[u] = ++timer;
}
void dfspotpuno(int u, int pret)
{
for(int x : g[u])
{
if(x==pret) continue;
if(u==1) s.clear();
dfspotpuno(x, u);
}
if(u==1) return;
s.insert(b[u]);
for(int x : lcaa[u]) s.erase(x);
if(!len(s)) c[u] = 1;
}
// Returns how many flagged vertices (c[] != 0) in u's subtree have no
// flagged proper descendant. The root (vertex 1) always returns 0.
// stanje is 0 while no flagged vertex has been met on the path from the root.
// It becomes 1 at the first flagged vertex and grows by one per level below
// that. Every vertex reached with stanje == 1 pushes its count into grupa.
int dfs(int u, int pret, int stanje)
{
    int leaves = 0;
    for (const int child : g[u]) {
        if (child == pret) continue;
        int child_state;
        if (stanje != 0) {
            child_state = stanje + 1;
        } else {
            child_state = c[child] ? 1 : 0;
        }
        leaves += dfs(child, u, child_state);
    }
    if (u == 1) return 0;
    if (leaves == 0 && c[u]) leaves = 1;
    if (stanje == 1) grupa.push_back(leaves);
    return leaves;
}
int main()
{
ios
cin>>n>>k;
for(int i = 0;i<n-1;i++)
{
int a, b;
cin>>a>>b;
g[a].pb(b);
g[b].pb(a);
}
for(int i = 1;i<=n;i++) cin>>b[i], cnt[b[i]]++;
dfstime(1, -1);
for(int k = 1;k<=19;k++)
{
for(int u = 1;u<=n;u++) up[u][k] = up[up[u][k-1]][k-1];
}
for(int i = 1;i<=n;i++)
{
if(lcagrupe[b[i]]==0) lcagrupe[b[i]] = i;
else lcagrupe[b[i]] = lca(lcagrupe[b[i]], i);
}
for(int i = 1;i<=n;i++)
{
if(lcagrupe[b[i]]) lcaa[lcagrupe[b[i]]].insert(b[i]);
}
dfspotpuno(1, -1);
dfs(1, -1, 0);
int ans = 0;
if(len(grupa)==1)
{
cout<<1+grupa[0]/2;
return 0;
}
/*cout<<len(grupa)<<endl;
for(int x : grupa) cout<<x<<' ';
cout<<endl;*/
for(int x : grupa) ans += x;
ans = (ans+1)/2;
cout<<ans;
return 0;
}
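/*
Extra test input 1 (same format as above; main reads the next n numbers as
labels regardless of line breaks, so the ten labels here span five lines):
10 8
1 2
2 3
2 4
1 5
5 8
5 6
6 7
8 9
8 10
6 6 7 8 6 5
5
4
3
3
*/
/*
Extra test input 2:
10 6
1 2
2 3
2 4
4 5
4 6
1 7
7 8
8 9
9 10
1 3 5 6 6 6 1 2 4 4
*/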
/*
Judge results (pasted):
# | Result    | Run time | Memory   | Grader output
1 | Incorrect | 43 ms    | 94276 KB | Output isn't correct
2 | Halted    | 0 ms     | 0 KB     | -
*/
/*
Judge results (pasted):
# | Result    | Run time | Memory   | Grader output
1 | Incorrect | 43 ms    | 94276 KB | Output isn't correct
2 | Halted    | 0 ms     | 0 KB     | -
*/
/*
Judge results (pasted):
# | Result    | Run time | Memory   | Grader output
1 | Incorrect | 43 ms    | 94276 KB | Output isn't correct
2 | Halted    | 0 ms     | 0 KB     | -
*/
/*
Judge results (pasted):
# | Result    | Run time | Memory    | Grader output
1 | Correct   | 100 ms   | 107676 KB | Output is correct
2 | Correct   | 171 ms   | 117952 KB | Output is correct
3 | Incorrect | 44 ms    | 94808 KB  | Output isn't correct
4 | Halted    | 0 ms     | 0 KB      | -
*/
/*
Judge results (pasted):
# | Result    | Run time | Memory   | Grader output
1 | Incorrect | 43 ms    | 94276 KB | Output isn't correct
2 | Halted    | 0 ms     | 0 KB     | -
*/