// oooo
/*
har chi delet mikhad bebar ~
gitar o ba khodet nabar! ~
;Amoo_Hasan;
*/
#include<bits/stdc++.h>
//#pragma GCC optimize("O3,no-stack-protector,unroll-loops")
//#pragma GCC target("avx2,fma")
using namespace std;
typedef long long ll;
typedef long double ld;
// Size of a container as an int.
#define Sz(x) int((x).size())
// Full iterator range of a container.
#define All(x) (x).begin(), (x).end()
// Debug print in the form "name : value".
#define wtf(x) cout<<#x <<" : " <<x <<endl
#define mak make_pair
//constexpr int PRI = 1000696969;
// N bounds every id used to index the shared arrays below: vertex ids,
// segment-tree node ids and color ids (main() offsets colors by `total`).
// LOG is the number of binary-lifting levels; get_par can lift by at most
// 2^LOG - 1 levels.
constexpr ll INF = 1e18, N = 3e6 + 10, MOD = 1e9 + 7, LOG = 20;
// n: number of vertices; k: number of colors.
int n, k;
// s[v]: color id of vertex v (shifted by `total` in main); sub[v]: subtree size;
// head[v]: first vertex of v's heavy chain as passed down by hld();
// link[t]: vertex placed at HLD position t;
// fin[v]: segment-tree leaf node id of vertex v (also used as a DSU id).
int s[N], sub[N], head[N], link[N], fin[N];
// par[v][i]: 2^i-th ancestor of v; h[v]: depth;
// [st[v], fn[v]): HLD positions occupied by v's subtree; tim: next free position.
int par[N][LOG], h[N], st[N], fn[N], tim;
// Disjoint-set union over all ids: sz = class size, pv = parent pointer.
int sz[N], pv[N];
// ans: number of contracted-graph nodes with exactly one neighbour;
// total: offset added to every color id.
int ans, total;
// seg[v]: a color id to be merged into every leaf under segment node v, or -1.
int seg[N];
// vc[c]: vertices whose color id is c; adj: tree adjacency lists;
// nei: neighbour lists of DSU representatives in the contracted graph.
vector<int> vc[N], adj[N], nei[N];
// Roots the subtree at x and fills, for every vertex in it:
//   par[v][i] - the 2^i-th ancestor (binary lifting),
//   h[v]      - the depth (root has depth 0),
//   sub[v]    - the subtree size (used by hld to pick heavy children).
// p is x's parent, or -1 when x is the root. The root is made its own parent,
// so par[][] never contains -1 and lifting never indexes outside the arrays.
// (The previous version read par[-1][..] and h[-1] for the root.)
void dfs(int x, int p = -1) {
    const bool is_root = (p < 0);
    par[x][0] = is_root ? x : p;
    // Ancestors are processed before descendants, so par[par[x][i-1]] is ready.
    for(int i = 1; i < LOG; i++) par[x][i] = par[par[x][i - 1]][i - 1];
    h[x] = is_root ? 0 : h[p] + 1;
    sub[x] = 1;
    for(auto j : adj[x]) {
        if(j == p) continue;
        dfs(j, x);
        sub[x] += sub[j];
    }
}
// Heavy-light decomposition of the subtree rooted at x.
// Assigns x the HLD position st[x] (and records link[st[x]] = x), descends into
// the child with the largest sub[] first so each heavy chain occupies a
// contiguous range of positions, and sets fn[x] to one past the last position
// used by x's subtree. head[v] is the first vertex of v's chain.
// p is x's parent (-1 for the root); hi is the head of x's chain.
// When hi is -1 (the initial call), x starts a new chain. The previous version
// stored head = -1 for the whole root chain, which made solve() read h[-1]
// and st[-1].
void hld(int x, int p = -1, int hi = -1) {
    if(hi == -1) hi = x;
    link[tim] = x;
    st[x] = tim++;
    head[x] = hi;
    int bz = -1;  // heavy child: child with the largest subtree
    for(auto j : adj[x]) {
        if(j == p) continue;
        if(bz == -1 || sub[bz] < sub[j]) bz = j;
    }
    if(bz == -1) {
        fn[x] = tim;
        return;
    }
    // The heavy child continues the current chain; light children start new ones.
    hld(bz, x, hi);
    for(auto j : adj[x]) {
        if(j == p || j == bz) continue;
        hld(j, x, j);
    }
    fn[x] = tim;
}
// Returns the ancestor of x that lies exactly y levels above it.
// Each set bit i of y performs a jump of 2^i levels via par[][i].
int get_par(int x, int y) {
    int cur = x;
    for(int bit = 0; bit < LOG; ++bit) {
        const bool take = (y >> bit) & 1;
        if(take) cur = par[cur][bit];
    }
    return cur;
}
// Lowest common ancestor of x and y using the binary-lifting table.
// The deeper endpoint is first lifted to the other's depth; then both
// climb together while their 2^i-th ancestors still differ.
int lca(int x, int y) {
    if(h[y] < h[x]) swap(x, y);
    y = get_par(y, h[y] - h[x]);
    if(x == y) return x;
    int i = LOG;
    while(i--) {
        if(par[x][i] == par[y][i]) continue;
        x = par[x][i];
        y = par[y][i];
    }
    return par[x][0];
}
// DSU lookup: returns the representative of x's class and points every
// vertex on the walked path directly at it (path compression).
int find(int x) {
    int root = x;
    while(pv[root] != root) root = pv[root];
    while(pv[x] != root) {
        const int next = pv[x];
        pv[x] = root;
        x = next;
    }
    return root;
}
// DSU union by size: joins the classes of x and y. The larger class's
// representative becomes the new root (ties keep x's representative).
void merge(int x, int y) {
    int a = find(x);
    int b = find(y);
    if(a == b) return;
    if(sz[b] > sz[a]) swap(a, b);
    pv[b] = a;
    sz[a] += sz[b];
}
// Initializes segment node v covering HLD positions [l, r): clears its pending
// color (-1) and, at a leaf, records that vertex link[l] lives in node v.
void build(int l = 0, int r = n, int v = 1) {
    seg[v] = -1;
    if(r - l >= 2) {
        const int mid = (l + r) >> 1;
        build(l, mid, 2 * v);
        build(mid, r, 2 * v + 1);
        return;
    }
    fin[link[l]] = v;
}
// Attaches color id val to every segment node that exactly covers a piece of
// the half-open position range [s, e). A node that already holds a pending
// color has that color merged with val instead of being overwritten.
void upd(int s, int e, int val, int l = 0, int r = n, int v = 1) {
    const bool disjoint = r <= s || e <= l;
    if(disjoint) return;
    const bool covered = s <= l && r <= e;
    if(!covered) {
        const int mid = (l + r) >> 1;
        upd(s, e, val, l, mid, 2 * v);
        upd(s, e, val, mid, r, 2 * v + 1);
        return;
    }
    if(seg[v] != -1) merge(seg[v], val);
    else seg[v] = val;
}
// Pushes every pending color down the segment tree. A child with no pending
// color inherits its parent's, otherwise the two are merged. At a leaf, the
// pending color is merged with the color of the vertex stored there.
void relax(int l = 0, int r = n, int v = 1) {
    const int tag = seg[v];
    if(r - l < 2) {
        if(tag != -1) merge(tag, s[link[l]]);
        return;
    }
    if(tag != -1) {
        for(int c : {2 * v, 2 * v + 1}) {
            if(seg[c] == -1) seg[c] = tag;
            else merge(tag, seg[c]);
        }
    }
    const int mid = (l + r) >> 1;
    relax(l, mid, 2 * v);
    relax(mid, r, 2 * v + 1);
}
// Orders vertices by their HLD/DFS entry position.
bool cmp(int i, int j) {
    const int a = st[i];
    const int b = st[j];
    return a < b;
}
void solve(int x) {
vector<int> ver;
for(auto i : vc[x]) ver.push_back(i);
sort(All(ver), cmp);
int sz = Sz(ver);
for(int i = 1; i < sz; i++) {
ver.push_back(lca(ver[i - 1], ver[i]));
}
sort(All(ver), cmp);
ver.erase(unique(All(ver)), ver.end());
stack<int> mt;
mt.push(ver[0]);
merge(fin[ver[0]], x);
for(int i = 1; i < Sz(ver); i++) {
int v = ver[i];
while(fn[mt.top()] < fn[v]) mt.pop();
int p = mt.top();
merge(x, fin[v]);
while(v != -1 && h[v] >= h[p]) {
if(h[head[v]] < h[p]) break;
upd(st[head[v]], st[v], x);
v = head[v];
}
if(v != -1)
upd(st[p], st[v], x);
mt.push(v);
}
}
int main() {
ios :: sync_with_stdio(0), cin.tie(0); cout.tie(0);
for(int i = 0; i < N; i++) sz[i] = 1, pv[i] = i;
cin >>n >>k;
for(int i = 0; i < n - 1; i++) {
int u, v; cin >>u >>v;
--u, --v;
adj[u].push_back(v), adj[v].push_back(u);
}
dfs(0);
hld(0);
total = 2000005;
for(int i = 0; i < n; i++) {
cin >>s[i];
--s[i], s[i] += total;
vc[s[i]].push_back(i);
}
build();
for(int i = total; i < total + k; i++) {
solve(i);
}
relax();
for(int i = 0; i < N; i++) pv[i] = find(pv[i]);
for(int i = 0; i < n; i++) {
int u = fin[i];
for(auto j : adj[i]) {
int v = fin[j];
if(pv[u] == pv[v]) continue;
nei[pv[u]].push_back(pv[v]);
nei[pv[v]].push_back(pv[u]);
}
}
for(int i = 0; i < N; i++) {
nei[i].erase(unique(All(nei[i])), nei[i].end());
if(Sz(nei[i]) == 1) ans++;
}
cout<<(ans + 1) / 2;
return 0;
}
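/*
Grader result:
| # | Result (결과) | Execution time (실행 시간) | Memory (메모리) | Grader output |
|---|---------------|----------------------------|-----------------|---------------|
| 1 | Execution timed out | 3023 ms | 235212 KB | Time limit exceeded |
| 2 | Halted | 0 ms | 0 KB | - |
*/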
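/*
Grader result:
| # | Result (결과) | Execution time (실행 시간) | Memory (메모리) | Grader output |
|---|---------------|----------------------------|-----------------|---------------|
| 1 | Execution timed out | 3023 ms | 235212 KB | Time limit exceeded |
| 2 | Halted | 0 ms | 0 KB | - |
*/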
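/*
Grader result:
| # | Result (결과) | Execution time (실행 시간) | Memory (메모리) | Grader output |
|---|---------------|----------------------------|-----------------|---------------|
| 1 | Execution timed out | 3023 ms | 235212 KB | Time limit exceeded |
| 2 | Halted | 0 ms | 0 KB | - |
*/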
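/*
Grader result:
| # | Result (결과) | Execution time (실행 시간) | Memory (메모리) | Grader output |
|---|---------------|----------------------------|-----------------|---------------|
| 1 | Execution timed out | 3068 ms | 251332 KB | Time limit exceeded |
| 2 | Halted | 0 ms | 0 KB | - |
*/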
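/*
Grader result:
| # | Result (결과) | Execution time (실행 시간) | Memory (메모리) | Grader output |
|---|---------------|----------------------------|-----------------|---------------|
| 1 | Execution timed out | 3023 ms | 235212 KB | Time limit exceeded |
| 2 | Halted | 0 ms | 0 KB | - |
*/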