// This submission was migrated from the previous version of oj.uz, which used a different machine for grading. It may receive a different result if resubmitted.
#include<bits/stdc++.h>
using namespace std;
// Competitive-programming shorthand macros. In the visible code only pb, sz
// and owo are used; fi, se, all, lb, ub, MOD, INF and time__ are unused
// template leftovers, and debug is used only by time__.
#define pb push_back
#define fi first
#define se second
// Container size as a signed int (note: the argument is not parenthesised in the expansion).
#define sz(a) (int)(a.size())
#define all(a) a.begin(),a.end()
#define lb lower_bound
#define ub upper_bound
// Fast I/O: decouple C++ streams from C stdio and untie cin from cout.
#define owo ios_base::sync_with_stdio(0);cin.tie(0);
#define MOD (ll)(998244353)
#define INF (ll)(1e18)
// printf-style message to stderr, flushed immediately.
#define debug(...) fprintf(stderr, __VA_ARGS__),fflush(stderr)
// Usage: time__("label") { ... }  -- runs the attached statement once, then
// prints the elapsed clock() time for it to stderr via debug().
// Quirk: if clock() returns 0 on entry the loop condition is false and the
// attached statement is skipped entirely.
#define time__(d) for(long blockTime = 0; (blockTime == 0 ? (blockTime=clock()) != 0 : false);\
debug("%s time : %.4fs\n", d, (double)(clock() - blockTime) / CLOCKS_PER_SEC))
// Short type aliases used throughout the solution.
using ll  = long long int;               // 64-bit signed integer
using ld  = long double;
using PII = pair<ll,ll>;                 // pair of 64-bit integers
using pii = pair<int,int>;               // pair of ints
using vii = vector<vector<int>>;         // 2-D int grid / adjacency lists
using VII = vector<vector<ll>>;          // 2-D 64-bit grid
const int MAXN = 5e5+5;
// tin/tout: pre-order entry timestamp and the largest timestamp inside the subtree.
// dep: depth below the DFS root. par[i][v]: the 2^i-th ancestor of v; entries for
// the root are never written by dfs and keep their zero-initialised value 0
// (main roots the tree at vertex 0, so lifting past the root stays at 0).
// c: colour of each city (filled by main).
int tin[MAXN],tout[MAXN],dep[MAXN],par[21][MAXN],c[MAXN];
std::vector<int>adj[MAXN];
int timer = 0;

/// Euler-tour DFS from `v` whose parent is `u` (-1 for the root).
/// Fills tin/tout, dep and the binary-lifting table par[0..20] for every
/// vertex below `v`, visiting children in adjacency order.
/// Uses an explicit stack instead of recursion: a path-shaped tree with ~5e5
/// vertices would otherwise need ~5e5 nested calls and can overflow the stack.
void dfs(int v,int u){
    struct Frame { int node, parent; std::size_t next; };
    std::vector<Frame> stk;
    stk.push_back({v, u, 0});
    tin[v] = ++timer;
    while(!stk.empty()){
        Frame &top = stk.back();
        if(top.next == adj[top.node].size()){
            // Every child is finished: close this vertex's subtree interval.
            tout[top.node] = timer;
            stk.pop_back();
            continue;
        }
        const int x = adj[top.node][top.next++];
        if(x == top.parent)continue;
        const int from = top.node; // copy now: push_back below may invalidate `top`
        par[0][x] = from;
        dep[x] = dep[from]+1;
        for(int i=1;i<=20;i++){
            par[i][x] = par[i-1][par[i-1][x]];
        }
        tin[x] = ++timer;
        stk.push_back({x, from, 0});
    }
}
// True iff u is an ancestor of v (a vertex counts as its own ancestor):
// v's entry time must lie inside u's [tin, tout] subtree interval.
// Not called anywhere in the visible code.
bool ancestor(int v,int u){
    const bool startsInside = tin[u] <= tin[v];
    const bool endsInside = tout[v] <= tout[u];
    return startsInside && endsInside;
}
// Lowest common ancestor of v and u using the binary-lifting table par[][]
// and depths dep[] filled by dfs. A vertex is considered its own ancestor.
int lca(int v,int u){
    // Make u the deeper (or equally deep) endpoint.
    if(dep[u] < dep[v])swap(u,v);
    // Lift u by the depth difference, one set bit at a time.
    int bit = 0;
    for(int diff = dep[u]-dep[v]; diff > 0; diff >>= 1, ++bit){
        if(diff & 1)u = par[bit][u];
    }
    if(u == v)return v;
    // Raise both while their 2^k-th ancestors differ; they stop just below the LCA.
    for(int k=20;k>=0;k--){
        if(par[k][u] == par[k][v])continue;
        u = par[k][u];
        v = par[k][v];
    }
    return par[0][u];
}
// color[k]: the cities whose (0-based) colour is k.
vector<int>color[MAXN];
// lc[x]: LCA of all cities sharing x's colour; for a DSU root, unite() keeps the
// shallowest such LCA over the whole set. p[x]: DSU parent (p[x] == x marks a root).
int lc[MAXN],p[MAXN];
/// Returns the DSU representative of x and compresses the path so every vertex
/// visited points directly at the root.
/// Iterative on purpose: with union-by-depth (no rank/size) a parent chain can
/// reach O(n) length, e.g. on a path-shaped tree, and a recursive find would
/// then nest ~n calls and risk overflowing the stack.
int find(int x){
    int root = x;
    while(p[root] != root)root = p[root];
    // Second pass: repoint every vertex on the walked path straight at root.
    while(p[x] != root){
        const int next = p[x];
        p[x] = root;
        x = next;
    }
    return root;
}
// Merges the DSU sets containing a and b. The deeper representative is hung
// under the shallower one (ties: a's root goes under b's), so each set's root
// is its topmost vertex. The surviving root inherits the shallower lc[] value.
void unite(int a,int b){
    int rootA = find(a);
    int rootB = find(b);
    if(rootA == rootB)return;
    if(dep[rootA] < dep[rootB])swap(rootA,rootB);
    // rootB is now no deeper than rootA.
    p[rootA] = rootB;
    if(dep[lc[rootA]] < dep[lc[rootB]])lc[rootB] = lc[rootA];
}
void dfs2(int v,int u){
int tar = find(v);
tar = dep[lc[tar]];
int cur = v;
//cout<<"node"<<" "<<cur<<'\n';
while(dep[cur] > tar){
int pp = find(cur);
if(dep[pp] <= tar)break;
unite(cur,par[0][pp]);
//cout<<cur<<" "<<pp<<'\n';
cur = par[0][pp];
}
//cout<<'\n';
for(int x:adj[v]){
if(x==u)continue;
dfs2(x,v);
}
}
// Reads n cities, k colours, n-1 tree edges and each city's colour (all 1-based
// in the input). Merges cities into DSU components, counts the components that
// touch exactly one other component (leaves of the compressed tree), and
// prints (leaves + 1) / 2. A single component has no leaves and prints 0.
// NOTE(review): presumably (leaves+1)/2 = ceil(leaves/2) is the number of
// extra edges needed to make the compressed tree 2-edge-connected; confirm
// against the problem statement.
int main()
{
    owo
    int n,k;
    cin>>n>>k;
    for(int i=0;i<n-1;i++){
        int v,u;
        cin>>v>>u;
        v--;u--;
        adj[v].pb(u);
        adj[u].pb(v);
    }
    // Bucket cities by colour and make every city its own DSU set.
    for(int i=0;i<n;i++){
        cin>>c[i];
        c[i]--;
        color[c[i]].pb(i);
        p[i] = i;
    }
    dfs(0,-1);
    // lc[x] = LCA of every city that shares x's colour.
    for(int i=0;i<k;i++){
        int v = -1;
        for(int x:color[i])v = (v == -1) ? x : lca(x,v);
        for(int x:color[i])lc[x] = v;
    }
    dfs2(0,-1);
    // Each DSU set is a connected subtree (unite() only joins a set with the
    // parent of its topmost vertex). Two connected subtrees of a tree share at
    // most one edge, so the number of crossing edges at a root equals the number
    // of distinct neighbouring components. A plain counter therefore replaces
    // the previous per-vertex std::set, saving tens of MB and the O(log) inserts.
    vector<int> deg(n, 0);
    for(int i=0;i<n;i++){
        const int ri = find(i);
        for(int x:adj[i]){
            if(find(x) != ri)deg[ri]++;
        }
    }
    int ans = 0;
    for(int i=0;i<n;i++){
        if(find(i) == i && deg[i] == 1)ans++;
    }
    cout<<(ans+1)/2;
}
/*
Grader results table scraped from oj.uz (the results had not finished loading when captured):
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
*/