This submission was migrated from a previous version of oj.uz, which used a different machine for grading. It may produce a different result if resubmitted.
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair<int, int> pii;
// Debug-print helpers. With DEBUG defined they write to stderr; otherwise they
// expand to empty statements so the calls can stay in the code.
#ifdef DEBUG
// display(x): prints "<expression text> = <value>" and a newline.
#define display(x) cerr << #x << " = " << (x) << endl;
// displaya(a, st, n): prints a[st..n] (both bounds inclusive) as "a = {v, v, ...}".
#define displaya(a, st, n)\
{cerr << #a << " = {";\
for(int qwq = (st); qwq <= (n); ++qwq) {\
if(qwq == (st)) cerr << ((a)[qwq]);\
else cerr << ", " << ((a)[qwq]);\
} cerr << "}" << endl;}
// displayv(v): displaya over indices 0 .. v.size() - 1.
#define displayv(v) displaya(v, 0, (int)(v).size() - 1)
// eprintf(...): printf-style output to stderr.
#define eprintf(...) fprintf(stderr, __VA_ARGS__)
#else
// No-op versions used when DEBUG is not defined.
#define display(x) ;
#define displaya(a, st, n) ;
#define displayv(v) ;
// The arguments are discarded; the `if(0)` guard means the fprintf never executes.
#define eprintf(...) if(0) fprintf(stderr, "...")
#endif
// Lowers `a` to `b` when `b` is strictly smaller; returns whether `a` changed.
template<typename T> bool chmin(T &a, const T &b) {
    if(a > b) {
        a = b;
        return true;
    }
    return false;
}
// Raises `a` to `b` when `b` is strictly larger; returns whether `a` changed.
template<typename T> bool chmax(T &a, const T &b) {
    if(a < b) {
        a = b;
        return true;
    }
    return false;
}
// Streams a pair as "(first, second)" and returns the stream for chaining.
template<typename A, typename B>
ostream& operator << (ostream& out, const pair<A, B> &p) {
    out << '(' << p.first;
    out << ", " << p.second;
    out << ')';
    return out;
}
#ifndef LOCAL
// Buffered replacement for getchar(), compiled in unless LOCAL is defined.
// `pool` holds one chunk of up to 32768 input bytes plus one extra slot for an
// end marker; `it` is the read cursor and starts at the end of the buffer so
// that the very first call refills it.
char pool[1<<15|1],*it=pool+32768;
// If the cursor has reached pool+32768: fread up to 32768 bytes from stdin into
// pool, store EOF immediately after the bytes actually read, reset the cursor
// and yield the first byte. Otherwise yield the byte at the cursor and advance.
// Note the yielded value has type char, not int.
#define getchar() (it>=pool+32768?(pool[fread(pool,sizeof(char),\
1<<15,stdin)]=EOF,*((it=pool)++)):*(it++))
#endif
// Reads the next decimal integer (with an optional leading '-') via getchar().
// Leading whitespace is skipped. Reading stops at the first non-digit, which is
// consumed. Returns 0 when no digits are found, e.g. at end of input.
// No overflow checking is performed.
inline int readint() {
    char c = getchar();
    // The <cctype> classifiers are only defined for values representable as
    // unsigned char (or EOF), so a possibly negative char is converted first.
    while(isspace(static_cast<unsigned char>(c))) c = getchar();
    bool negative = false;
    if(c == '-') {
        negative = true;
        c = getchar();
    }
    int value = 0;
    while(isdigit(static_cast<unsigned char>(c))) {
        value = value * 10 + (c - '0');
        c = getchar();
    }
    return negative ? -value : value;
}
// Array capacity for per-node data; nodes are indexed from 1.
const int maxN = 200000 + 5;
// n, k: the first two integers of the input (main iterates nodes 1..n and
// groups 1..k). g[i]: the value read for node i, used as its group id.
int n, k, g[maxN];
// Undirected adjacency lists built from the n - 1 input edges (presumably a tree).
vector<int> G[maxN];
// fam[t]: the nodes i with g[i] == t.
vector<int> fam[maxN];
// dep[u]: depth of u, with the root at depth 1.
// f[i][u]: the 2^i-th ancestor of u, or 0 when there is none.
int dep[maxN], f[20][maxN];
// pre[u]: preorder index assigned by dfs; dfs_clock: the last index handed out.
int pre[maxN], dfs_clock = 0;
// Traverses the tree below u, where fa is u's parent (pass a value that is not
// a node id, such as -1, for the root). Assigns pre[] in preorder to u and all
// its descendants, and sets dep[] and f[0][] (direct parent) for every
// descendant. dep[u] and f[0][u] must already be set by the caller.
//
// Uses an explicit stack instead of recursion: the recursion depth would equal
// the height of the tree (up to n on a path), which can overflow the call
// stack. Neighbours are visited in the same order as the recursive form.
void dfs(int u, int fa) {
    struct Frame {
        int node;      // vertex being expanded
        int parent;    // neighbour to skip when scanning G[node]
        size_t next;   // index of the next neighbour to look at
    };
    std::vector<Frame> frames;
    pre[u] = ++dfs_clock;
    frames.push_back({u, fa, 0});
    while(!frames.empty()) {
        Frame &top = frames.back();
        if(top.next == G[top.node].size()) {
            frames.pop_back();
            continue;
        }
        const int node = top.node;
        const int v = G[node][top.next++];
        if(v == top.parent) continue;
        dep[v] = dep[node] + 1;
        f[0][v] = node;
        pre[v] = ++dfs_clock;
        // `top` is not used after this point, so the push may reallocate safely.
        frames.push_back({v, node, 0});
    }
}
// Lowest common ancestor of x and y, found by binary lifting over f[][] and dep[].
int lca(int x, int y) {
    // Make x the deeper (or equally deep) endpoint.
    if(dep[x] < dep[y]) std::swap(x, y);

    // Raise x to the depth of y, one set bit of the depth gap at a time.
    const int gap = dep[x] - dep[y];
    for(int bit = 0; bit < 20; ++bit) {
        if((gap >> bit) & 1) x = f[bit][x];
    }
    if(x == y) return x;

    // Lift both nodes as far as possible while their ancestors still differ;
    // afterwards they sit directly below the common ancestor.
    for(int level = 19; level >= 0; --level) {
        const int ax = f[level][x];
        const int ay = f[level][y];
        if(ax != ay) {
            x = ax;
            y = ay;
        }
    }
    return f[0][x];
}
// Capacity for node ids of the auxiliary directed graph (link() asserts ids stay below it).
const int maxM = 200000 * 22 + 5;
// Capacity of the edge arrays; link() asserts the edge counter stays below maxE - 10.
const int maxE = 200000 * 22 * 2 + 200000 * 20;
// Node ids in use run from 1 to 20*n + k (the range solve() scans).
// Forward edges stored as linked lists: h[x] is the newest edge leaving x,
// last[e] the previous edge with the same source, to[e] the target of edge e.
// cm is the number of edges added so far; edge id 0 terminates a list.
int to[maxE], last[maxE], h[maxM], cm = 0;
// The same edges indexed by target: ih[y] is the newest edge entering y,
// ilast[e] the previous edge with the same target, ito[e] the source of edge e.
int ito[maxE], ilast[maxE], ih[maxM];
void link(int x, int y) {
cm++;
assert(cm < maxE - 10);
assert(x < maxM && y < maxM);
to[cm] = y; last[cm] = h[x]; h[x] = cm;
ito[cm] = x; ilast[cm] = ih[y]; ih[y] = cm;
}
// Maps (level k, tree node u) to a graph node id: each level occupies a
// consecutive band of n ids. The parameter k shadows the global k here.
int encode(int k, int u) {
    const int bandStart = k * n;
    return bandStart + u;
}
// vis[x]: visited mark shared by dfs1 and dfs2; solve() clears it before each pass.
bool vis[maxM];
// Nodes in the order dfs1 finished them; solve() consumes it from the back.
vector<int> stk;
// First pass of the SCC computation: depth-first search over the forward edges
// starting at u, marking nodes in vis[] and appending each node to stk once all
// of its outgoing edges have been explored (post-order).
//
// Written with an explicit stack because the graph has up to about 20*n + k
// nodes and a recursive search could nest deeply enough to overflow the call
// stack. Edges are scanned in the same order as the recursive form, so stk
// receives nodes in the same order.
void dfs1(int u) {
    // (node, next edge id to examine); reused across calls to avoid reallocating.
    static std::vector<std::pair<int, int> > frames;
    frames.clear();
    vis[u] = true;
    frames.emplace_back(u, h[u]);
    while(!frames.empty()) {
        const int node = frames.back().first;
        int e = frames.back().second;
        // Skip edges whose target has already been visited.
        while(e != 0 && vis[to[e]]) e = last[e];
        if(e == 0) {
            stk.push_back(node);
            frames.pop_back();
            continue;
        }
        // Remember where to resume in this node's edge list, then descend.
        frames.back().second = last[e];
        const int child = to[e];
        vis[child] = true;
        frames.emplace_back(child, h[child]);
    }
}
// cnt: id of the component currently being collected (solve() increments it
// before each dfs2 call). scc[x]: component id given to x by dfs2, 0 until then.
int cnt = 0, scc[maxM];
// now[0 .. len-1]: nodes gathered by the dfs2 traversal in progress; solve()
// resets len to 0 before each traversal.
int now[maxM], len = 0;
void dfs2(int u) {
vis[u] = true; scc[u] = cnt; now[len++] = u;
// for(int v : iH[u]) if(!vis[v]) dfs2(v);
for(int i = ih[u]; i; i = ilast[i])
if(!vis[ito[i]]) dfs2(ito[i]);
}
// Computes strongly connected components of the auxiliary graph (dfs1 for the
// finishing order, dfs2 on reversed edges to collect each component) and
// examines every component that has no edge leaving it. Among those containing
// at least one node with id above 20*n, it takes the smallest count of such
// nodes (starting from k as the upper bound) and returns that minimum minus one.
int solve() {
    const int groupBase = 20 * n;        // ids above this are counted below
    const int lastNode = groupBase + k;  // largest node id in use
    int best = k;

    // Pass 1: record finishing order over the forward edges.
    memset(vis, 0, sizeof(vis));
    for(int node = 1; node <= lastNode; ++node) {
        if(vis[node]) continue;
        dfs1(node);
    }

    // Pass 2: peel components off in reverse finishing order.
    memset(vis, 0, sizeof(vis));
    while(!stk.empty()) {
        const int root = stk.back();
        stk.pop_back();
        if(vis[root]) continue;

        ++cnt;
        len = 0;
        dfs2(root);

        int groupNodes = 0;   // members with id above groupBase
        bool closed = true;   // stays true while no edge leaves the component
        for(int j = 0; j < len; ++j) {
            const int member = now[j];
            if(member > groupBase) ++groupNodes;
            for(int e = h[member]; e != 0; e = last[e]) {
                if(scc[to[e]] != scc[member]) closed = false;
            }
        }
        if(closed && groupNodes != 0 && groupNodes < best) best = groupNodes;
    }
    return best - 1;
}
// Reads n, k, the n - 1 edges and the group id of every node; builds the
// binary-lifting tables; constructs the auxiliary graph; prints solve().
int main() {
    n = readint();
    k = readint();
    for(int e = 1; e < n; ++e) {
        int a = readint();
        int b = readint();
        G[a].push_back(b);
        G[b].push_back(a);
    }
    for(int v = 1; v <= n; ++v) {
        g[v] = readint();
        fam[g[v]].push_back(v);
    }

    // Root the tree at node 1 (depth 1, no parent) and fill the ancestor table.
    dep[1] = 1;
    f[0][1] = 0;
    dfs(1, -1);
    for(int level = 1; level < 20; ++level) {
        for(int v = 1; v <= n; ++v) {
            f[level][v] = f[level - 1][f[level - 1][v]];
        }
    }

    // Group t is represented by graph node groupBase + t.
    const int groupBase = 20 * n;

    // Level-0 node of every vertex points to that vertex's group node.
    for(int v = 1; v <= n; ++v) {
        link(encode(0, v), g[v] + groupBase);
    }

    // A level-`level` node stands for the 2^level vertices from v upward; it
    // points to its two halves. It exists only if that many vertices lie on
    // the path from v to the root.
    for(int level = 1; level < 20; ++level) {
        for(int v = 1; v <= n; ++v) {
            if(dep[v] < (1 << level)) continue;
            link(encode(level, v), encode(level - 1, v));
            link(encode(level, v), encode(level - 1, f[level - 1][v]));
        }
    }

    // Links the group node of `start` to table nodes that together cover the
    // `count` vertices start, parent(start), ... going upward.
    auto coverUpward = [&](int start, int count) {
        int cur = start;
        for(int bit = 0; bit < 20; ++bit) {
            if((count >> bit) & 1) {
                link(g[start] + groupBase, encode(bit, cur));
                cur = f[bit][cur];
            }
        }
    };

    for(int t = 1; t <= k; ++t) {
        vector<int> &members = fam[t];
        sort(members.begin(), members.end(), [&](int a, int b) {
            return pre[a] < pre[b];
        });
        // Cover the tree path between each pair of preorder-adjacent members.
        for(int i = 1; i < (int)members.size(); ++i) {
            const int u = members[i - 1];
            const int v = members[i];
            const int w = lca(u, v);
            coverUpward(u, dep[u] - dep[w]);       // u up to, but excluding, w
            coverUpward(v, dep[v] - dep[w] + 1);   // v up to and including w
        }
    }

    cout << solve() << endl;
    return 0;
}
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|
Fetching results... |