#include <cstdio>
#include <stdio.h>
#include <stdbool.h>
#include <iostream>
#include <map>
#include <vector>
#include <climits>
#include <stack>
#include <string>
#include <queue>
#include <algorithm>
#include <set>
#include <unordered_set>
#include <unordered_map>
#include <cmath>
#include <cctype>
#include <bitset>
#include <iomanip>
#include <cstring>
#include <numeric>
#include <cassert>
using namespace std;
// Shorthand macros. NOTE: `#define int long long` textually replaces every
// `int` token after this line, so every "int" variable, vector element type
// and return type in this file is 64-bit. That is also why main below is
// declared `int32_t main()` instead of `int main()`.
// (pii, mp, fi and se are not referenced by the code in this file.)
#define int long long
#define pii pair<int, int>
#define mp make_pair
#define pb push_back
#define fi first
#define se second
// counter: running preorder index handed out by dfs.
// ans: maximum maintained by dfs3; main prints (ans+1)/2.
int counter=0, ans=0;
// Per-vertex arrays, 1-indexed and sized n+1 in main:
//   depth - distance from the DFS root (dfs),
//   in    - preorder index assigned by dfs; rev is its inverse (rev[in[v]] == v),
//   val   - per-vertex values written by main and summed per subtree by dfs2,
//   die   - set to 1 by dfs2 for a vertex whose subtree val-sum is 0; read by dfs3.
vector<int> depth, in, rev, val, die;
// graph: undirected adjacency lists (main inserts each edge in both directions).
// twok: twok[v][i] is the 2^i-th ancestor of v, i in [0, 20); the root is its own parent.
vector<vector<int> > graph, twok;
void dfs(int node, int p, int d){
in[node]=++counter;
rev[counter]=node;
depth[node]=d;
twok[node][0]=p;
for (int i=1; i<20; ++i)twok[node][i]=twok[twok[node][i-1]][i-1];
for (auto num:graph[node])if (num!=p)dfs(num, node, d+1);
}
// Lowest common ancestor of a and b by binary lifting over twok[][].
// Requires dfs to have filled depth[] and twok[]. Only 20 ancestor levels
// are consulted, so depth differences are assumed to be below 2^20.
int lca(int a, int b){
    int deep=a, shallow=b;
    if (depth[deep]<depth[shallow])swap(deep, shallow);
    // Lift the deeper vertex to the shallower one's depth, one set bit at a time.
    const int gap=depth[deep]-depth[shallow];
    for (int bit=0; bit<20; ++bit){
        if ((gap>>bit)&1)deep=twok[deep][bit];
    }
    if (deep==shallow)return deep;
    // Climb both while their ancestors differ; they stop just below the LCA.
    for (int bit=19; bit>=0; --bit){
        if (twok[deep][bit]==twok[shallow][bit])continue;
        deep=twok[deep][bit];
        shallow=twok[shallow][bit];
    }
    return twok[deep][0];
}
int dfs2(int node, int p){
int res=val[node];
for (auto num:graph[node])if (num!=p){
int temp=dfs2(num, node);
if (!temp)die[num]=1;
res+=temp;
}
return res;
}
int dfs3(int node, int p, bool got){
int leaf=0;
for (auto num:graph[node])if (num!=p)leaf+=dfs3(num, node, got||die[num]);
ans=max(ans, leaf+got);
return max(leaf, die[node]);
}
/*
 * Input: n k, then n-1 undirected edges "a b", then n labels (the label of
 * vertex 1..n, expected in 1..k).
 *
 * Places +1 on every vertex and -|group| at each group's LCA, so that dfs2
 * marks exactly the edges no group straddles. dfs3 then fills `ans`, and the
 * program prints (ans+1)/2, i.e. ceil(ans/2).
 */
int32_t main(){
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int n, k, a, b;
    cin>>n>>k;
    die.resize(n+1, 0);
    in.resize(n+1);
    rev.resize(n+1);
    graph.resize(n+1);
    depth.resize(n+1);
    val.resize(n+1);
    twok.resize(n+1, vector<int>(20));
    vector<vector<int> > vect(k+1);   // vect[c] = vertices carrying label c
    for (int i=1; i<n; ++i){
        cin>>a>>b;
        graph[a].pb(b);
        graph[b].pb(a);
    }
    dfs(1, 1, 0);   // root at vertex 1, which is its own parent in twok
    for (int i=1; i<=n; ++i)cin>>a, vect[a].pb(i);
    for (int i=1; i<=k; ++i){
        // A label with no vertices contributes nothing. Without this guard the
        // sentinels below would survive the loop and rev[] would be indexed
        // far out of bounds.
        if (vect[i].empty())continue;
        // The LCA of a vertex set equals the LCA of its preorder-min and
        // preorder-max members.
        int mx=LLONG_MIN/2, mn=LLONG_MAX/2;
        for (auto v:vect[i])mx=max(mx, in[v]), mn=min(mn, in[v]), ++val[v];
        // After this, a subtree's val-sum counts its vertices whose group's
        // LCA lies outside it. It is zero exactly when no group crosses the
        // edge above that subtree.
        val[lca(rev[mx], rev[mn])]-=static_cast<int>(vect[i].size());
    }
    dfs2(1, -1);
    dfs3(1, -1, 0);
    cout<<(ans+1)/2;
}
/*
 * Judge status table that was copied together with the source (the results
 * were still loading at the time). Kept only as a record; not part of the
 * program.
 *
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 * # | Verdict | Execution time | Memory | Grader output |
 * ---|
 * Fetching results... |
 */