제출 #345502

# | 제출 시각 | 아이디 | 문제 | 언어 | 결과 | 실행 시간 | 메모리
345502 | Samboor | 수도 (JOI20_capital_city) | C++17
100 / 100
847 ms | 72300 KiB
#include <bits/stdc++.h>

using namespace std;

// Short type aliases used throughout the solution.
using ll = long long;
using ld = long double;
using pii = std::pair<int, int>;
using pll = std::pair<ll, ll>;
using vb = std::vector<bool>;
using vi = std::vector<int>;
using vvi = std::vector<std::vector<int>>;
using vl = std::vector<ll>;
using vvl = std::vector<std::vector<ll>>;

// Competitive-programming shorthands. Macro parameters avoid names that begin
// with a double underscore, because such identifiers are reserved for the
// implementation.
#define pb push_back
#define eb emplace_back
#define mp std::make_pair
#define mt std::make_tuple
#define st first
#define nd second
// Inclusive ascending loop: VAR takes START, START+1, ..., END.
#define FOR(VAR, START, END) for (int VAR = (START); VAR <= (END); VAR++)
// Inclusive descending loop: VAR takes START, START-1, ..., END.
#define FORB(VAR, START, END) for (int VAR = (START); VAR >= (END); VAR--)
// Inclusive stepped loop: VAR takes START, START+DIFF, ... while VAR <= END.
// (Must compare against the END parameter; a bare undefined `END` never compiles.)
#define FORK(VAR, START, END, DIFF) for (int VAR = (START); VAR <= (END); VAR += (DIFF))
#define all(VAR) (VAR).begin(), (VAR).end()
#define rall(VAR) (VAR).rbegin(), (VAR).rend()
// Prints "expr: value" and flushes (debug only).
#define DEBUG(VAR) std::cout << #VAR << ": " << (VAR) << std::endl;

// Debug printer for pairs: writes "[first, second]" to the given stream `out`
// (not unconditionally to std::cout) and returns it for chaining. Accepts const
// and temporary pairs.
template<typename T1, typename T2>
std::ostream &operator<<(std::ostream &out, const std::pair<T1, T2> &value)
{
	out << "[" << value.first << ", " << value.second << "]";
	return out;
}

// Debug printer for vectors: writes "[a, b, c]" followed by a newline to the
// given stream `out` and returns it. Each element is printed with its own
// operator<<, so nested (including const) vectors work. Uses '\n' rather than
// std::endl to avoid forcing a flush on every call.
template<typename T>
std::ostream &operator<<(std::ostream &out, const std::vector<T> &values)
{
	out << "[";
	for (std::size_t i = 0; i < values.size(); ++i) {
		if (i > 0)
			out << ", ";
		out << values[i];
	}
	out << "]" << '\n';
	return out;
}

// INF: sentinel larger than any component size; LOG: number of binary-lifting levels.
constexpr int INF = 1000000000;
constexpr int LOG = 20;

int n, k;                                  // vertex count and colour count (read in main)
std::vector<int> color;                    // color[v]: 0-based colour of vertex v
std::vector<int> colorLca;                 // colorLca[c]: LCA of the vertices coloured c
std::vector<std::pair<int, int>> preOrder; // preOrder[v]: [first, second] pre-order range of v's subtree
std::vector<int> depth;                    // depth[v] (main initialises the root to 1)
std::vector<std::vector<int>> tree;        // undirected tree adjacency lists
std::vector<std::vector<int>> g;           // colour graph built in main: edge c -> d
std::vector<std::vector<int>> ancestor;    // ancestor[v][i]: binary-lifting jump table
std::vector<std::vector<int>> colorInd;    // colorInd[c]: pre-order indices of colour-c vertices
std::vector<bool> vis;                     // visited flags, reused by every traversal

// Depth-first traversal of `tree` starting at `cur` (main calls it once, with
// vertex 0 as the root). For every vertex u it reaches it records:
//   preOrder[u].first  - u's pre-order index (the counter is static, so
//                        numbering continues across calls),
//   preOrder[u].second - the largest pre-order index in u's subtree, so the
//                        subtree of u is the index range [first, second],
//   colorInd[color[u]] - u's pre-order index, appended in increasing order,
//   ancestor[v][0], depth[v] - parent and depth of every non-start vertex v,
//   vis[u] = true.
// An explicit stack replaces recursion: recursion depth equals the tree height,
// which can exhaust the call stack on deep, path-like trees. The visiting order
// and every recorded value are the same as in the recursive formulation.
void dfs1(int cur) {
	static int nextInd=0;
	// Each frame holds a vertex and the position of the next neighbour to inspect.
	std::vector<std::pair<int, std::size_t>> frames;
	auto enter = [&](int u) {
		preOrder[u].first=preOrder[u].second=nextInd++;
		colorInd[color[u]].push_back(preOrder[u].first);
		vis[u]=true;
		frames.emplace_back(u, 0);
	};
	enter(cur);
	while(!frames.empty()) {
		const int u=frames.back().first;
		std::size_t &pos=frames.back().second;
		if(pos<tree[u].size()) {
			const int v=tree[u][pos++];
			if(!vis[v]) {
				ancestor[v][0]=u;
				depth[v]=depth[u]+1;
				enter(v); // may reallocate `frames`; `pos` is not used afterwards
			}
		} else {
			// u's subtree is finished: extend the parent's subtree range to cover it,
			// which is what happens after each child call returns in the recursive form.
			frames.pop_back();
			if(!frames.empty())
				preOrder[frames.back().first].second=preOrder[u].second;
		}
	}
}

// Lowest common ancestor of vertices a and b, using `depth` and the binary-lifting
// table `ancestor` (ancestor[v][i] = ancestor[ancestor[v][i-1]][i-1], built in main).
int lca(int a, int b) {
	int deep=a, shallow=b;
	if(depth[deep]<depth[shallow]) {
		deep=b;
		shallow=a;
	}
	// Lift the deeper vertex greedily by decreasing powers of two until the
	// remaining depth difference is consumed.
	int remaining=depth[deep]-depth[shallow];
	for(int lvl=LOG-1; lvl>=0; lvl--) {
		const int jump=1<<lvl;
		if(remaining<jump)
			continue;
		remaining-=jump;
		deep=ancestor[deep][lvl];
	}
	if(deep==shallow)
		return deep;
	// Climb both vertices together while their 2^lvl-th ancestors differ; they
	// finish as the two distinct vertices directly below the answer.
	for(int lvl=LOG-1; lvl>=0; lvl--) {
		const int upDeep=ancestor[deep][lvl];
		const int upShallow=ancestor[shallow][lvl];
		if(upDeep==upShallow)
			continue;
		deep=upDeep;
		shallow=upShallow;
	}
	return ancestor[deep][0];
}

// State for the two-pass (Kosaraju) SCC computation over the colour graph:
std::vector<int> comp;            // comp[c]: component index assigned to colour c by dfs3
std::stack<int> s;                // colours pushed by dfs2 when finished (last finished on top)
std::vector<std::vector<int>> ig; // transpose of g, filled in main
// First pass of Kosaraju's SCC algorithm: depth-first search over the colour
// graph `g` from `cur`, marking reached colours in `vis` and pushing each colour
// onto `s` once all of its out-neighbours have been processed (finishing order).
// An explicit frame stack replaces recursion, whose depth could reach k. The
// visiting order and the push order onto `s` are the same as in the recursive form.
void dfs2(int cur) {
	std::vector<std::pair<int, std::size_t>> frames;
	vis[cur]=true;
	frames.emplace_back(cur, 0);
	while(!frames.empty()) {
		const int u=frames.back().first;
		std::size_t &pos=frames.back().second;
		if(pos<g[u].size()) {
			const int v=g[u][pos++];
			if(!vis[v]) {
				vis[v]=true;
				frames.emplace_back(v, 0); // may reallocate; `pos` not used afterwards
			}
		} else {
			s.push(u);
			frames.pop_back();
		}
	}
}

// Second pass of Kosaraju's SCC algorithm: labels `cur` and every colour reachable
// from it in the transposed graph `ig` through not-yet-visited colours with
// component index `ind`, marking them in `vis`. Uses a worklist instead of
// recursion, whose depth could reach k. The labelled set does not depend on
// visiting order: it is exactly the set reachable from `cur` through unvisited
// colours.
void dfs3(int ind, int cur) {
	std::vector<int> todo{cur};
	vis[cur]=true;
	comp[cur]=ind;
	while(!todo.empty()) {
		const int u=todo.back();
		todo.pop_back();
		for(int v:ig[u])
			if(!vis[v]) {
				vis[v]=true;
				comp[v]=ind;
				todo.push_back(v);
			}
	}
}

// Reads a tree with n vertices (n-1 undirected edges, 1-indexed in the input)
// and a colour in [1, k] for every vertex. It builds a directed graph on colours
// where c -> d means "including colour c also forces colour d". It then finds
// the strongly connected components of that graph. Finally it prints
// (size of the smallest component with no outgoing edge) - 1.
// NOTE(review): presumably this is the minimum number of extra cities to merge in
// JOI 2020 "Capital City" — confirm against the problem statement.
int main()
{
	ios_base::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	
	cin >> n >> k;
	tree.resize(n);
	color.resize(n);
	preOrder.resize(n);
	colorLca.resize(k);
	colorInd.resize(k);
	g.resize(k);
	vis.resize(n,0);
	depth.resize(n,1); // the root (vertex 0) gets depth 1
	ancestor.resize(n, vector<int>(LOG,0)); // unset entries point at vertex 0 (the root)
	// Tree edges, converted to 0-based indices.
	for(int i=0; i<n-1; i++) {
		int a,b;
		cin >> a >> b;
		a--,b--;
		tree[a].pb(b);
		tree[b].pb(a);
	}
	// Vertex colours, converted to 0-based.
	for(int i=0; i<n; i++) {
		cin >> color[i];
		color[i]--;
	}
	// Pre-order ranges, per-colour pre-order lists, parents and depths.
	dfs1(0);
	// Binary-lifting table: ancestor[v][i] is the 2^i-th ancestor of v (stays at the root).
	for(int i=1; i<LOG; i++)
		for(int v=0; v<n; v++)
			ancestor[v][i]=ancestor[ancestor[v][i-1]][i-1];
	// dfs1 already appends indices in increasing order; the sort is a safeguard
	// for the lower_bound below.
	for(int i=0; i<k; i++)
		sort(colorInd[i].begin(), colorInd[i].end());
	// colorLca[c] := some vertex of colour c, then folded with lca over all of them.
	// A colour with no vertex would keep the value 0 from resize().
	for(int i=0; i<n; i++)
		colorLca[color[i]]=i;
	for(int i=0; i<n; i++)
		colorLca[color[i]]=lca(colorLca[color[i]], i);
	// Build the colour graph.
	for(int i=0; i<n; i++) {
		int c=color[i];
		// i lies strictly below the LCA of its colour, so the path from i to that
		// LCA goes through i's parent: colour c needs the parent's colour.
		if(depth[colorLca[c]]<depth[i])
			g[c].pb(color[ancestor[i][0]]);
		// Neighbours with a larger pre-order index are i's children. If child v's
		// subtree (pre-order range [st, nd]) contains a vertex of colour c, the
		// path from that vertex up to i passes through v: colour c needs color[v].
		for(int v:tree[i])
			if(preOrder[v].st>preOrder[i].st) {
				auto it=lower_bound(colorInd[c].begin(), colorInd[c].end(), preOrder[v].st);
				if(it!=colorInd[c].end() && (*it)<=preOrder[v].nd)
					g[c].pb(color[v]);
			}
	}
	// Kosaraju pass 1: finishing order of colours onto `s` (vis reused, now sized k).
	comp.resize(k,0);
	vis.resize(k);
	fill(vis.begin(), vis.end(), false);
	for(int i=0; i<k; i++)
		if(!vis[i])
			dfs2(i);
	// Transposed colour graph.
	ig.resize(k);
	for(int i=0; i<k; i++)
		for(int j:g[i])
			ig[j].pb(i);
	// Kosaraju pass 2: process colours in reverse finishing order on the transpose;
	// each new search labels one strongly connected component.
	int nextInd=0;
	fill(vis.begin(), vis.end(), false);
	while(!s.empty()) {
		int cur=s.top();
		s.pop();
		if(vis[cur])
			continue;
		dfs3(nextInd++, cur);
	}
	// A component is valid only if no colour in it has an edge to another component.
	vector<int> compSize(nextInd,0);
	vector<bool> valid(nextInd,1);
	for(int i=0; i<k; i++)
		compSize[comp[i]]++;
	for(int i=0; i<k; i++)
		for(int j:g[i])
			if(comp[i]!=comp[j])
				valid[comp[i]]=0;
	// Smallest valid component (the condensation always has at least one sink).
	int result=INF;
	for(int i=0; i<nextInd; i++)
		if(valid[i])
			result=min(result, compSize[i]);
	cout << result-1 << '\n';
	return 0;
}
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...
# | Verdict | Execution time | Memory | Grader output
Fetching results...