// Make the best become better
// No room for laziness
#include<bits/stdc++.h>
#define int long long
#define pb push_back
#define fi first
#define se second
using namespace std;
// Type aliases. Note that `int` itself is #defined to `long long` at the top
// of the file, so every `int` below (and in the functions) is a long long.
using ll = long long;
using ld = long double;
using ull = unsigned long long;
// Time-seeded RNG; not used by the code shown here.
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());
const int maxN = 1e6 + 5;   // capacity of every per-node array below
const int mod = 1e9 + 7;    // not used by the code shown here
const ll oo = 1e18;         // not used by the code shown here
int n, m;                   // n: node count; m: labels 1..m are scanned via vc[] in Solve
vector<int> adj[maxN];      // adjacency lists of the input tree
vector<int> vc[maxN];       // vc[x]: nodes whose input label is x
int f[maxN][20];            // f[v][k]: 2^k-th ancestor of v (0 past the root); dfs fills k = 0..18
int depth[maxN];            // depth[v]: distance from the node dfs starts at (node 1 in Solve)
pair<int, int> e[maxN];     // e[i]: i-th input edge, i = 1..n-1
int sum[maxN];              // per-node accumulator built in Solve and summed over subtrees in dfs1;
                            // a non-zero subtree total makes dfs1 unite the node with its parent
int lab[maxN];              // DSU: negative = root (magnitude is set size), otherwise parent index
int deg[maxN];              // deg[r]: degree of component r (a DSU root) in the contracted tree
// Reads the whole input:
//   n m
//   n-1 pairs "u v" -> tree edges, kept in e[1..n-1] and in both adj lists
//   n values        -> label of node i; node i is appended to vc[label]
void ReadInput()
{
    cin >> n >> m;
    for(int idx = 1; idx <= n - 1; idx++)
    {
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
        e[idx] = std::make_pair(a, b);
    }
    for(int node = 1; node <= n; node++)
    {
        int label;
        cin >> label;
        vc[label].push_back(node);
    }
}
// Returns the DSU representative of u. Roots hold a negative value in lab[];
// every node on the walked path is re-pointed directly at the root
// (path compression), done iteratively in two passes.
int Findset(int u)
{
    int root = u;
    while(lab[root] >= 0)
        root = lab[root];
    while(u != root)
    {
        int next = lab[u];
        lab[u] = root;
        u = next;
    }
    return root;
}
// Merges the sets containing u and v (union by size). A root's lab[] holds
// minus its set size, so the root with the more negative value is the larger
// set and becomes the parent. Does nothing if u and v are already together.
void Unite(int u, int v)
{
    int a = Findset(u);
    int b = Findset(v);
    if(a == b) return;
    if(lab[b] < lab[a]) swap(a, b);   // keep the larger set's root in `a`
    lab[a] += lab[b];
    lab[b] = a;
}
// Roots the subtree hanging from u (whose parent is par) and fills, for every
// node y strictly below u:
//   depth[y] = depth[parent] + 1, and
//   f[y][k]  = 2^k-th ancestor of y for k = 0..18 (0 once past the root,
//              because f[0][*] is never written and stays zero).
// depth[u] and f[u][*] must already be set by the caller; for the root call
// dfs(1, 0) they are the zero-initialised globals.
//
// Iterative (BFS order) instead of recursive: a path-shaped tree would recurse
// once per node and may overflow the call stack. BFS still reaches a parent
// before its children, so every f[y][i - 1] read below is already final.
void dfs(int u, int par)
{
    std::vector<std::pair<int, int>> order;   // (node, parent) in BFS order from u
    order.push_back(std::make_pair(u, par));
    for(std::size_t head = 0; head < order.size(); head++)
    {
        // Copy out: push_back below may reallocate `order`.
        int x = order[head].first, p = order[head].second;
        for(int y : adj[x])
        {
            if(y == p) continue;
            f[y][0] = x;
            depth[y] = depth[x] + 1;
            for(int i = 1; i <= 18; i++)
                f[y][i] = f[f[y][i - 1]][i - 1];
            order.push_back(std::make_pair(y, x));
        }
    }
}
// Lowest common ancestor of u and v, using the depth[] and f[][] tables built
// by dfs. Only lifting levels 0..18 are consulted, so depth gaps must stay
// below 2^19.
int lca(int u, int v)
{
    // Make u the deeper endpoint.
    if(depth[v] > depth[u]) swap(u, v);

    // Lift u by exactly depth[u] - depth[v] levels, lowest bit first.
    const int gap = depth[u] - depth[v];
    for(int bit = 0; bit <= 18; bit++)
    {
        if(gap & (1LL << bit)) u = f[u][bit];
    }
    if(u == v) return u;

    // Climb both in lockstep, taking every jump that keeps them apart; they
    // end as the two children of the LCA, one level below it.
    for(int lvl = 18; lvl >= 0; lvl--)
    {
        if(f[u][lvl] == f[v][lvl]) continue;
        u = f[u][lvl];
        v = f[v][lvl];
    }
    return f[u][0];
}
// Bottom-up pass over the subtree of u (whose parent is par). Each node's
// sum[] is increased by its children's final sums. Whenever a node's
// accumulated sum is non-zero, the node is united with its parent. The start
// node's total is not propagated into sum[par].
//
// Iterative instead of recursive, to avoid one stack frame per level on
// path-shaped trees. Nodes are collected in BFS order and then processed in
// reverse, which guarantees every child is finished before its parent.
// Parent 0 is the "no parent" sentinel (as in the dfs1(1, 0) call), so the
// root is never united with node 0.
void dfs1(int u, int par)
{
    std::vector<std::pair<int, int>> order;   // (node, parent) in BFS order from u
    order.push_back(std::make_pair(u, par));
    for(std::size_t head = 0; head < order.size(); head++)
    {
        // Copy out: push_back below may reallocate `order`.
        int x = order[head].first, p = order[head].second;
        for(int y : adj[x])
            if(y != p) order.push_back(std::make_pair(y, x));
    }
    for(std::size_t k = order.size(); k-- > 0; )
    {
        int x = order[k].first, p = order[k].second;
        if(sum[x] != 0 && p != 0) Unite(x, p);
        if(k > 0) sum[p] += sum[x];   // k == 0 is the start node u
    }
}
void Solve()
{
dfs(1, 0);
for(int i=1; i<=m; i++)
{
if(vc[i].empty()) continue;
int tmp = vc[i][0];
for(int v : vc[i])
{
sum[v] = -1;
tmp = lca(tmp, v);
}
sum[tmp] += vc[i].size();
}
memset(lab, -1, sizeof lab);
dfs1(1, 0);
map<pair<int, int>, int> mp;
for(int i=1; i<n; i++)
{
int u = Findset(e[i].fi), v = Findset(e[i].se);
if(u == v || mp[{u, v}]) continue;
deg[u]++;
deg[v]++;
mp[{u, v}] = mp[{v, u}] = 1;
}
int cnt = 0;
for(int i=1; i<=n; i++)
{
if(lab[i] > 0) continue;
if(deg[i] == 1) cnt++;
}
cout << (cnt + 1) / 2;
}
// Entry point: configure fast stream I/O, then read the input and solve.
int32_t main()
{
    cin.tie(nullptr);                  // don't flush cout before every cin read
    ios_base::sync_with_stdio(false);  // decouple C++ streams from C stdio
    ReadInput();
    Solve();
    return 0;
}
/* Grader results for one test group (pasted from the judge):
 * #  | Result    | Execution time | Memory   | Grader output
 * 1  | Correct   | 26 ms          | 55120 KB | Output is correct
 * 2  | Incorrect | 26 ms          | 55120 KB | Output isn't correct
 * 3  | Halted    | 0 ms           | 0 KB     | -
 */
/* Grader results for one test group (pasted from the judge):
 * #  | Result    | Execution time | Memory   | Grader output
 * 1  | Correct   | 26 ms          | 55120 KB | Output is correct
 * 2  | Incorrect | 26 ms          | 55120 KB | Output isn't correct
 * 3  | Halted    | 0 ms           | 0 KB     | -
 */
/* Grader results for one test group (pasted from the judge):
 * #  | Result    | Execution time | Memory   | Grader output
 * 1  | Correct   | 26 ms          | 55120 KB | Output is correct
 * 2  | Incorrect | 26 ms          | 55120 KB | Output isn't correct
 * 3  | Halted    | 0 ms           | 0 KB     | -
 */
/* Grader results for one test group (pasted from the judge):
 * #  | Result    | Execution time | Memory   | Grader output
 * 1  | Correct   | 97 ms          | 78916 KB | Output is correct
 * 2  | Correct   | 206 ms         | 94376 KB | Output is correct
 * 3  | Incorrect | 31 ms          | 56012 KB | Output isn't correct
 * 4  | Halted    | 0 ms           | 0 KB     | -
 */
/* Grader results for one test group (pasted from the judge):
 * #  | Result    | Execution time | Memory   | Grader output
 * 1  | Correct   | 26 ms          | 55120 KB | Output is correct
 * 2  | Incorrect | 26 ms          | 55120 KB | Output isn't correct
 * 3  | Halted    | 0 ms           | 0 KB     | -
 */