Submission #216906

# | Time | Username | Problem | Language | Result | Execution time | Memory
216906 | (not captured) | Toadologist | Capital City (JOI20_capital_city) | C++17 | 100 / 100 | 2636 ms | 270268 KiB
#include <bits/stdc++.h>
using namespace std;
// Short aliases: LL is a 64-bit-capable signed integer, pii a pair of ints.
using LL = long long;
using pii = std::pair<int, int>;
// Debug tracing helpers. With DEBUG defined:
//   display(x)          prints "x = <value>" to cerr;
//   displaya(a, st, n)  prints a[st..n] (both ends inclusive) as "a = {..}";
//   displayv(v)         displaya over every index of v;
//   eprintf(...)        forwards to fprintf(stderr, ...).
// Without DEBUG none of them evaluates its arguments or prints anything.
#ifdef DEBUG
#define display(x) cerr << #x << " = " << (x) << endl;
#define displaya(a, st, n)\
	{cerr << #a << " = {";\
	for(int qwq = (st); qwq <= (n); ++qwq) {\
		if(qwq == (st)) cerr << ((a)[qwq]);\
		else cerr << ", " << ((a)[qwq]);\
	} cerr << "}" << endl;}
#define displayv(v) displaya(v, 0, (int)(v).size() - 1)
#define eprintf(...) fprintf(stderr, __VA_ARGS__)
#else
#define display(x) ;
#define displaya(a, st, n) ;
#define displayv(v) ;
// A plain void expression instead of `if(0) ...`: the old form hid an `if`
// inside the macro, so in `if (c) eprintf(...); else ...` the `else` attached
// to the hidden `if` and the caller's else-branch silently never ran.
#define eprintf(...) ((void)0)
#endif
// Lowers `a` to `b` when `b` is strictly smaller; returns whether `a` changed.
template<typename T>
bool chmin(T &a, const T &b) {
	if(!(a > b)) return false;
	a = b;
	return true;
}
// Raises `a` to `b` when `b` is strictly larger; returns whether `a` changed.
template<typename T>
bool chmax(T &a, const T &b) {
	if(!(a < b)) return false;
	a = b;
	return true;
}
// Streams a pair as "(first, second)" and returns the stream for chaining.
template<typename A, typename B>
std::ostream& operator << (std::ostream& out, const std::pair<A, B> &p) {
	out << '(' << p.first;
	out << ", " << p.second;
	out << ')';
	return out;
}
#ifndef LOCAL
// Block-buffered replacement for getchar(), compiled in unless LOCAL is defined.
// pool holds 1<<15 input bytes plus one extra slot for the EOF marker; `it` is
// the read cursor and starts at pool+32768 so the very first call refills.
char pool[1<<15|1],*it=pool+32768;
// When the cursor has reached pool+32768: fread up to 1<<15 bytes from stdin,
// store EOF (narrowed to char) right after the last byte read, rewind the
// cursor and yield pool[0]. Otherwise yield the byte under the cursor and
// advance. A refill is triggered only by the cursor reaching pool+32768.
#define getchar() (it>=pool+32768?(pool[fread(pool,sizeof(char),\
	1<<15,stdin)]=EOF,*((it=pool)++)):*(it++))
#endif
// Reads the next decimal integer through getchar().
// Skips leading whitespace, accepts one optional '-', then consumes digits up
// to and including the first non-digit character. Returns 0 if no digit
// follows the optional sign.
inline int readint() {
	// The <cctype> classifiers are undefined for negative arguments other than
	// EOF, so every char is widened through unsigned char before the call.
	const auto is_space = [](char ch) { return isspace(static_cast<unsigned char>(ch)) != 0; };
	const auto is_digit = [](char ch) { return isdigit(static_cast<unsigned char>(ch)) != 0; };

	char c = getchar();
	while(is_space(c)) c = getchar();
	const bool negative = (c == '-');
	if(negative) c = getchar();
	int a = 0;
	// (c - '0') is reduced first: adding the raw character code to a*10 would
	// overflow int for values close to INT_MAX before '0' is subtracted.
	while(is_digit(c)) a = a*10 + (c - '0'), c = getchar();
	return negative ? -a : a;
}

// Capacity of every per-vertex array (200000 vertices plus slack).
const int maxN = 200000 + 5;
// n, k: the first two integers read by main; g[i]: the label read for vertex i.
int n, k, g[maxN];
// Undirected adjacency lists built from the n - 1 edges read by main.
vector<int> G[maxN];
// fam[c]: the vertices i with g[i] == c.
vector<int> fam[maxN];
// dep[u]: depth of u, with dep[1] = 1.
// f[j][u]: ancestor of u that is 2^j steps up (0 once the walk leaves the root).
int dep[maxN], f[20][maxN];
// pre[u]: preorder index given to u by dfs; dfs_clock: last index handed out.
int pre[maxN], dfs_clock = 0;

// Preorder walk of G starting at `u`, entered from neighbour `fa`.
// Stamps pre[u]; for every other neighbour records its direct parent in f[0]
// and its depth in dep before recursing into it.
void dfs(int u, int fa) {
	pre[u] = ++dfs_clock;
	const auto &neighbours = G[u];
	for(const int v : neighbours) {
		if(v == fa) continue;
		f[0][v] = u;
		dep[v] = dep[u] + 1;
		dfs(v, u);
	}
}
// Lowest common ancestor of x and y, using the depth array `dep` and the
// 2^j-ancestor table `f`.
int lca(int x, int y) {
	if(dep[x] < dep[y]) swap(x, y);

	// Lift the deeper vertex until both are at the same depth.
	const int gap = dep[x] - dep[y];
	for(int bit = 0; bit < 20; ++bit) {
		if((gap >> bit) & 1) x = f[bit][x];
	}
	if(x == y) return x;

	// Jump both upwards as long as their ancestors still differ; afterwards
	// they are distinct children of the answer.
	for(int bit = 19; bit >= 0; --bit) {
		if(f[bit][x] == f[bit][y]) continue;
		x = f[bit][x];
		y = f[bit][y];
	}
	return f[0][x];
}

// Capacity of the node-indexed arrays of the auxiliary directed graph.
const int maxM = 200000 * 22 + 5;
// Capacity of the edge-indexed arrays.
const int maxE = 200000 * 22 * 2 + 200000 * 20;
// Highest node id that solve() visits is 20*n + k.
// Forward adjacency as linked lists: h[x] is the newest edge leaving x,
// last[e] the edge added before e from the same node, to[e] the head of e.
// cm counts edges; ids start at 1, so 0 terminates a list.
int to[maxE], last[maxE], h[maxM], cm = 0;
// Reverse adjacency with the same layout: ih[y] is the newest edge entering y,
// ilast[e] the previous edge into the same node, ito[e] the tail of e.
int ito[maxE], ilast[maxE], ih[maxM];
// Adds the directed edge x -> y, registering it in both the forward lists
// (h/last/to) and the reverse lists (ih/ilast/ito). Asserts that the edge and
// node capacities are respected.
void link(int x, int y) {
	const int e = ++cm;
	assert(cm < maxE - 10);
	assert(x < maxM && y < maxM);

	// Push onto the front of x's outgoing list.
	last[e] = h[x];
	to[e] = y;
	h[x] = e;

	// Push onto the front of y's incoming list.
	ilast[e] = ih[y];
	ito[e] = x;
	ih[y] = e;
}

// Node id for level `k` anchored at vertex `u`: each level occupies its own
// run of n consecutive ids (the parameter `k` hides the global k here).
int encode(int k, int u) {
	const int level_base = k * n;
	return level_base + u;
}

// vis[x]: visited flag for node x; solve() clears it before each of its passes.
bool vis[maxM];
// Nodes in the order dfs1 finished them; solve() consumes it from the back.
vector<int> stk;
// First pass: depth-first search along forward edges from `u`. Each node is
// appended to stk once all nodes reachable from it have been finished.
void dfs1(int u) {
	vis[u] = true;
	int e = h[u];
	while(e != 0) {
		const int head = to[e];
		if(!vis[head]) dfs1(head);
		e = last[e];
	}
	stk.push_back(u);
}
// cnt: id of the component currently being built (solve() increments it before
// each dfs2 root); scc[x]: component id given to x by dfs2, 0 until then.
int cnt = 0, scc[maxM];
// now[0..len): nodes collected by dfs2 for the component being built.
int now[maxM], len = 0;
void dfs2(int u) {
	vis[u] = true; scc[u] = cnt; now[len++] = u;
//	for(int v : iH[u]) if(!vis[v]) dfs2(v);
	for(int i = ih[u]; i; i = ilast[i])
		if(!vis[ito[i]]) dfs2(ito[i]);
}
// Two-pass (Kosaraju-style) component search over nodes 1 .. 20*n + k.
// For each component it counts the members with id > 20*n (the label nodes
// main creates at g[u] + 20*n) and checks that no forward edge leaves the
// component. Returns the smallest such count among closed components that
// contain at least one label node, minus one; k - 1 if none beats k.
int solve() {
	const int node_count = 20 * n + k;
	int ans = k;

	// Pass 1: record finishing order on the forward graph.
	memset(vis, 0, sizeof(vis));
	for(int u = 1; u <= node_count; ++u) {
		if(vis[u]) continue;
		dfs1(u);
	}

	// Pass 2: peel components off the reverse graph, latest finisher first.
	memset(vis, 0, sizeof(vis));
	while(!stk.empty()) {
		const int root = stk.back();
		stk.pop_back();
		if(vis[root]) continue;

		++cnt;
		len = 0;
		dfs2(root);

		int res = 0;
		bool ok = true;
		for(int j = 0; j < len; ++j) {
			const int x = now[j];
			if(x > 20 * n) ++res;
			// An edge into a different (or not yet numbered) component means
			// this component is not closed.
			for(int e = h[x]; e != 0; e = last[e]) {
				if(scc[to[e]] != scc[x]) ok = false;
			}
		}
		if(ok && res > 0) chmin(ans, res);
	}
	return ans - 1;
}

// Reads n, k, the n - 1 edges and the per-vertex labels, builds the auxiliary
// directed graph and prints solve().
//
// Node layout: encode(j, u) for levels j = 0..19, then label c at 20*n + c.
//   - every level-0 node points to the label of its vertex;
//   - a level-j node (j >= 1) points to the two level-(j-1) nodes it is made
//     of, and exists only when u has at least 2^j vertices on its root path;
//   - each label points to level nodes covering the paths between members that
//     are adjacent in preorder.
int main() {
	n = readint();
	k = readint();
	for(int e = 1; e < n; ++e) {
		const int x = readint();
		const int y = readint();
		G[x].push_back(y);
		G[y].push_back(x);
	}
	for(int i = 1; i <= n; ++i) {
		g[i] = readint();
		fam[g[i]].push_back(i);
	}

	// Root at vertex 1 (depth 1, parent 0) and fill the ancestor table.
	dep[1] = 1;
	f[0][1] = 0;
	dfs(1, -1);
	for(int lvl = 1; lvl < 20; ++lvl) {
		for(int u = 1; u <= n; ++u) {
			f[lvl][u] = f[lvl - 1][f[lvl - 1][u]];
		}
	}

	const int label_base = 20 * n;
	for(int u = 1; u <= n; ++u) {
		link(encode(0, u), label_base + g[u]);
	}
	for(int lvl = 1; lvl < 20; ++lvl) {
		for(int u = 1; u <= n; ++u) {
			if(dep[u] < (1 << lvl)) continue;
			link(encode(lvl, u), encode(lvl - 1, u));
			link(encode(lvl, u), encode(lvl - 1, f[lvl - 1][u]));
		}
	}

	// Links the label of `from` to level nodes covering `count` consecutive
	// vertices, starting at `from` and moving towards the root.
	const auto cover = [&](int from, int count) {
		int x = from;
		for(int bit = 0; bit < 20; ++bit) {
			if(!((count >> bit) & 1)) continue;
			link(g[from] + label_base, encode(bit, x));
			x = f[bit][x];
		}
	};
	const auto by_preorder = [&](int a, int b) { return pre[a] < pre[b]; };

	for(int t = 1; t <= k; ++t) {
		auto &members = fam[t];
		sort(members.begin(), members.end(), by_preorder);
		const int count = (int)members.size();
		for(int i = 0; i + 1 < count; ++i) {
			const int u = members[i];
			const int v = members[i + 1];
			const int w = lca(u, v);
			cover(u, dep[u] - dep[w]);      // u up to, but excluding, w
			cover(v, dep[v] - dep[w] + 1);  // v up to and including w
		}
	}

	cout << solve() << endl;
	return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...