#include <bits/stdc++.h>
using namespace std;
// Competitive-programming shorthands. Every macro argument is parenthesized
// so that compound expressions (e.g. `a|b`, `n-1`) expand correctly.
#define FOR(i,a,b) for(int i=(a);i<=(int)(b);i++)
#define FORD(i,a,b) for(int i=(a);i>=(int)(b);i--)
#define ll long long
#define fi first
#define se second
#define pb push_back
#define all(a) (a).begin(),(a).end()
// Bit i of mask, as 0 or 1.
#define BIT(mask,i) (((mask)>>(i))&1)
// 2^a as a long long (a left shift; it must not be a comparison).
#define MASK(a) (1LL<<(a))
// Sort v and drop duplicate values in place.
#define uni(v) sort(all(v)); v.resize(unique(all(v)) - v.begin())
#define pii pair <int, int>
#define vi vector <int>
#define vl vector <ll>
// Raises a to b when b is larger (tested with a < b).
// Returns true iff a was changed.
template <class A,class B>
bool maximize(A &a, const B b)
{
    const bool improves = a < b;
    if (improves)
        a = b;
    return improves;
}

// Lowers a to b when b is smaller (tested with a > b).
// Returns true iff a was changed.
template <class A,class B>
bool minimize(A &a, const B b)
{
    const bool improves = a > b;
    if (improves)
        a = b;
    return improves;
}
const int maxn = 5e5 + 5;
int n, k;
vi g[maxn], state[maxn], adj[maxn];
int par[maxn], h[maxn], tin[maxn], tout[maxn], euler, belong[maxn];
int sz[maxn], heavy[maxn], pos[maxn], cur[maxn], head[maxn], crt, timedfs;
// Rooted DFS from u: sets par/h for children, tin on entry, tout as the last
// entry time inside u's subtree, sz as subtree size, and heavy[u] as the first
// child (in g[u] order) whose subtree is strictly largest.
void dfs(int u)
{
    tin[u] = ++euler;
    sz[u] = 1;
    for (size_t idx = 0; idx < g[u].size(); ++idx)
    {
        const int child = g[u][idx];
        if (child != par[u])
        {
            par[child] = u;
            h[child] = h[u] + 1;
            dfs(child);
            sz[u] += sz[child];
            if (sz[heavy[u]] < sz[child])
                heavy[u] = child;
        }
    }
    tout[u] = euler;
}
// Heavy-light decomposition pass; must run after dfs() has filled par/heavy.
// pos[u]  : visit index. The heavy child is visited first, so every chain
//           occupies a contiguous range of positions.
// cur[u]  : id of the chain containing u. Ids grow in visit order.
// head[c] : first (topmost) vertex visited on chain c.
void dfs_hld(int u)
{
// The first vertex reached on a fresh chain id becomes that chain's head.
if(!head[crt]) head[crt] = u;
pos[u] = ++ timedfs;
cur[u] = crt;
// Continue the current chain through the heavy child.
if(heavy[u])dfs_hld(heavy[u]);
for(int v : g[u]) if(v != par[u] && v != heavy[u])
{
// Every light child starts a new chain.
crt++;
dfs_hld(v);
}
}
// Lowest common ancestor via the HLD chains. Chain ids grow in visit order,
// so the endpoint on the later chain cannot have its chain head above the
// other endpoint; that endpoint jumps above its head until both share a chain.
// The shallower of the two is then the answer.
int LCA(int u, int v)
{
    for (; cur[u] != cur[v]; )
    {
        int &mover = (cur[u] > cur[v]) ? u : v;
        mover = par[head[cur[mover]]];
    }
    return h[u] < h[v] ? u : v;
}
// Range-chmax segment tree: st holds the node maximum, lazy the pending chmax
// still to be pushed to the children (0 = nothing pending).
int st[4 * maxn], lazy[4 * maxn];

// Applies "chmax with val" to node id, both to its value and its pending tag.
void apply(int id, int val)
{
    if (st[id] < val)
        st[id] = val;
    if (lazy[id] < val)
        lazy[id] = val;
}
// Propagates node id's pending chmax tag (if any) to both children and clears it.
void pushdown(int id)
{
    const int pending = lazy[id];
    if (!pending)
        return;
    apply(2 * id, pending);
    apply(2 * id + 1, pending);
    lazy[id] = 0;
}
// Range chmax: for every position in [u, v], value = max(value, val).
// Node id covers [l, r].
void update(int id, int l, int r, int u, int v, int val)
{
    if (r < u || v < l)
        return;                         // disjoint
    if (u <= l && r <= v)
    {
        apply(id, val);                 // fully covered: tag and stop
        return;
    }
    pushdown(id);
    const int mid = (l + r) / 2;
    update(2 * id, l, mid, u, v, val);
    update(2 * id + 1, mid + 1, r, u, v, val);
    st[id] = max(st[2 * id], st[2 * id + 1]);
}
// Maximum over the positions in [u, v] that fall inside node id's range
// [l, r]; 0 when the ranges do not intersect.
int get(int id, int l, int r, int u, int v)
{
    if (r < u || v < l)
        return 0;
    if (u <= l && r <= v)
        return st[id];
    pushdown(id);
    const int mid = (l + r) / 2;
    const int fromLeft = get(2 * id, l, mid, u, v);
    const int fromRight = get(2 * id + 1, mid + 1, r, u, v);
    return max(fromLeft, fromRight);
}
// Returns nonzero iff u is a proper ancestor of v (false when u == v).
// tout[u] is the largest entry time inside u's subtree, so it EQUALS tout[v]
// when v is the last vertex entered in u's subtree. The upper-bound test must
// therefore be <=; a strict comparison rejects that ancestor.
int isparent(int u, int v) {return tin[u] < tin[v] && tout[v] <= tout[u];}
// Disjoint-set forest over state labels; lab[x] == x marks a root.
int lab[maxn];

// Root of x's set. Every vertex on the walked path is re-pointed straight at
// the root (full path compression), done iteratively.
int find(int x)
{
    int root = x;
    while (lab[root] != root)
        root = lab[root];
    while (lab[x] != root)
    {
        const int nxt = lab[x];
        lab[x] = root;
        x = nxt;
    }
    return root;
}

// Unites the sets of u and v; the root of u's set becomes the root.
// Returns false if they were already in the same set.
bool joint(int u, int v)
{
    const int keep = find(u);
    const int absorb = find(v);
    if (keep == absorb)
        return false;
    lab[absorb] = keep;
    return true;
}
int getpath(int u, int v)
{
int p = LCA(u, v);
int res = 0;
while(cur[u] != cur[p])
{
maximize(res, get(1, 1, n, pos[head[cur[u]]], pos[u]));
u = par[head[cur[u]]];
}
while(cur[v] != cur[p])
{
maximize(res, get(1, 1, n, pos[head[cur[v]]], pos[v]));
v = par[head[cur[v]]];
}
if(pos[u] < pos[v]) maximize(res, get(1, 1, n, pos[u], pos[v]));
else maximize(res, get(1, 1, n, pos[v], pos[u]));
return res;
}
void uppath(int u, int v, int id)
{
int p = LCA(u, v);
while(cur[u] != cur[p])
{
update(1, 1, n, pos[head[cur[u]]], pos[u], id);
u = par[head[cur[u]]];
}
while(cur[v] != cur[p])
{
update(1, 1, n, pos[head[cur[v]]], pos[v], id);
v = par[head[cur[v]]];
}
if(pos[u] < pos[v]) update(1, 1, n, pos[u], pos[v], id);
else update(1, 1, n, pos[v], pos[u], id);
}
// Walks the virtual tree of label `id` rooted at u (children taken from adj[]).
// type == 0 (query pass): joins `id` with the largest label already written
//   on u itself and on the tree path of every virtual edge (u, v).
// type == 1 (update pass): writes `id` onto the tree path of every virtual edge.
// The passes must stay separate. In the original combined condition
// `type == 0 && mark > 0 ... else uppath(...)`, the else branch also ran in
// the query pass whenever the mark was 0. solve() calls cook in increasing id
// order, so those early writes of `id` then hid the older labels from the
// remaining queries of the same pass.
void DFS(int u, int id, int type)
{
    if (type == 0)
    {
        const int mark = getpath(u, u);
        if (mark > 0)
            joint(mark, id);
    }
    for (int v : adj[u])
    {
        DFS(v, id, type);
        if (type == 0)
        {
            const int mark = getpath(u, v);
            if (mark > 0)
                joint(mark, id);
        }
        else
            uppath(u, v, id);
    }
}
void cook(int id)
{
vector <int> vers;
for(int u : state[id]) vers.push_back(u);
sort(all(vers), [&](int u, int v) {return tin[u] < tin[v];});
int k = vers.size() - 1;
for(int i = 1; i <= k; i++) vers.push_back(LCA(vers[i - 1], vers[i]));
sort(all(vers), [&](int u, int v) {return tin[u] < tin[v];});
vers.resize(unique(all(vers)) - vers.begin());
vector <int> p;
for(int v : vers)
{
while((int)p.size() >= 2 && !isparent(p.back(), v))
{
adj[p[p.size() - 2]].push_back(p.back());
p.pop_back();
}
p.push_back(v);
}
while(p.size() >= 2)
{
adj[p[p.size() - 2]].push_back(p.back());
p.pop_back();
}
DFS(p[0], id, 0);
DFS(p[0], id, 1);
for(int v : vers) adj[v].clear();
}
// Number of childless vertices found by demla().
int leaf = 0;

// DFS over adj[] from u, treating p as u's parent; adds 1 to `leaf` for every
// visited vertex that has no neighbour other than its parent.
void demla(int u, int p)
{
    int children = 0;
    for (int v : adj[u])
    {
        if (v == p)
            continue;
        ++children;
        demla(v, u);
    }
    if (children == 0)
        ++leaf;
}
void solve()
{
cin >> n >> k;
for(int i = 1; i < n; i++)
{
int u, v; cin >> u >> v;
g[u].pb(v); g[v].pb(u);
}
for(int i = 1; i <= n; i++)
{
int x; cin >> x;
belong[i]= x;
state[x].push_back(i);
}
dfs(1);
dfs_hld(1);
for(int i = 1; i <= k; i++)
lab[i] = i;
for(int i = 1; i <= k; i++)
cook(i);
for(int i = 1; i <= n; i++)
{
int X = find(belong[i]);
for(int v : g[i])
{
int Y = find(belong[v]);
if(X == Y) continue;
adj[X].push_back(Y);
}
}
demla(1, 1);
cout << leaf / 2;
}
// Entry point: fast iostreams, then redirection to kieuoanh.inp / kieuoanh.out
// when kieuoanh.inp exists, then a single call to solve().
signed main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    #define kieuoanh "kieuoanh"
    // Close the probe handle (it used to leak), and honour freopen's result
    // instead of ignoring it (the compiler's -Wunused-result warnings).
    if (FILE *probe = fopen(kieuoanh".inp", "r"))
    {
        fclose(probe);
        if (!freopen(kieuoanh".inp", "r", stdin)) return 1;
        if (!freopen(kieuoanh".out", "w", stdout)) return 1;
    }
    int tst = 1;
    // cin >> tst;
    while (tst--) solve();
    return 0;
}
Standard error (stderr) messages at compile time
mergers.cpp: In function 'int main()':
mergers.cpp:270:16: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
270 | freopen(kieuoanh".inp","r",stdin);
| ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~
mergers.cpp:271:16: warning: ignoring return value of 'FILE* freopen(const char*, const char*, FILE*)' declared with attribute 'warn_unused_result' [-Wunused-result]
271 | freopen(kieuoanh".out","w",stdout);
| ~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |