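// Submission #330103
//
// #      | Submission time | ID        | Problem                                | Language | Result    | Execution time | Memory
// 330103 | (not captured)  | limabeans | 수도 / Capital City (JOI20_capital_city) | C++17    | 100 / 100 | 782 ms         | 79596 KiB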
#include <bits/stdc++.h>
using namespace std;

// Prints `x` followed by a newline to standard output, then terminates the
// program with exit status 0.
template <class T>
void out(T x) {
    std::cout << x << std::endl;
    std::exit(0);
}
// Debug helper: prints "<expression text> is <value>" on its own line.
#define watch(x) std::cout << (#x) << " is " << (x) << std::endl





// 64-bit integer alias (not referenced by the code shown here).
using ll = long long;

// Capacity of every per-vertex / per-colour array; vertices are 1-indexed.
constexpr int maxn = 1000000 + 5;

int n, k;                  // n: vertex count; k: read from input, otherwise unused here
std::vector<int> g[maxn];  // g[v]: neighbours of v in the tree
int c[maxn];               // c[v]: colour of vertex v
std::vector<int> C[maxn];  // C[col]: all vertices whose colour is col

int ans = 1000000000;      // smallest valid colour count found so far ("infinity" initially)

bool viz[maxn];            // viz[v]: v has already been used as a centroid (removed)

int siz[maxn];             // siz[v]: subtree size computed by dfs1


// Computes siz[v] (size of v's subtree) for every vertex v in the component
// containing `at`, with the tree rooted at `at` and `p` treated as at's parent
// (-1 for none). Vertices with viz[] set are treated as removed. siz[p] is
// never touched.
//
// Iterative on purpose: the recursive form nests one call per tree level,
// which on a path-shaped tree is O(n) frames and can overflow the stack.
// Vertices are listed parent-before-child, then sizes are accumulated in
// reverse order, so every child is finished before its parent.
void dfs1(int at, int p) {
    vector<int> order{at};   // visit order (BFS-like, parents first)
    vector<int> parent{p};   // parent[i] is the parent of order[i]
    for (size_t i = 0; i < order.size(); ++i) {
        const int v = order[i];   // copy: push_back below may reallocate
        const int pv = parent[i];
        siz[v] = 1;
        for (int to : g[v]) {
            if (viz[to] || to == pv) continue;
            order.push_back(to);
            parent.push_back(v);
        }
    }
    // Index 0 is the root; its parent `p` is not part of this component.
    for (size_t i = order.size(); i-- > 1;) {
        siz[parent[i]] += siz[order[i]];
    }
}

// Returns a centroid of the component containing `at` (viz[] vertices are
// removed). After dfs1 roots the component at `at`, the walk repeatedly steps
// into the first child (in adjacency order) whose subtree holds at least half
// of the component's vertices. It stops when no such child exists. The parent
// is recognised as the neighbour with a larger siz[] value.
int findCenter(int at) {
    dfs1(at, -1);
    const int total = siz[at];
    while (true) {
        const int here = at;
        auto heavy = find_if(g[here].begin(), g[here].end(), [&](int nb) {
            return !viz[nb] && siz[nb] <= siz[here] && 2 * siz[nb] >= total;
        });
        if (heavy == g[here].end()) return here;
        at = *heavy;
    }
}


// par[v]: parent of v in its current component, rooted at the component's
// centroid (set by dfs2).
int par[maxn];
// center[v]: centroid of the component v was last assigned to by dfs2;
// solve() overwrites it with 0 once v has been processed.
int center[maxn];

// Roots the component containing `at` at `at` (viz[] vertices are treated as
// removed). Sets par[at] = p and, for every other reached vertex, par[] to its
// parent. Sets center[v] = cent for every reached vertex v.
//
// Uses an explicit stack instead of recursion: on a path-shaped tree the
// recursive form nests O(n) frames and can overflow the call stack.
void dfs2(int at, int p, int cent) {
    vector<int> stk{at};
    par[at] = p;
    while (!stk.empty()) {
        const int v = stk.back();
        stk.pop_back();
        center[v] = cent;
        for (int to : g[v]) {
            // par[v] is final here: in a tree every vertex is reached exactly once.
            if (viz[to] || to == par[v]) continue;
            par[to] = v;
            stk.push_back(to);
        }
    }
}

// vizcol[col]: colour `col` is already in the current candidate set; cleared
// again at the end of each solve() call.
bool vizcol[maxn];

// Centroid-decomposition driver.
//
// 1. Finds the centroid `root` of the component containing `at` and records
//    par[] / center[] for that component.
// 2. Builds the smallest colour set that contains root's colour and is closed
//    under the rule "if a vertex is in the set, every vertex on its path to
//    `root` contributes its colour, and every vertex of an added colour is
//    enqueued".
// 3. If no added colour has a vertex outside this component, the size of the
//    set is a candidate and `ans` keeps the minimum.
// 4. Marks `root` removed (viz) and recurses into each remaining sub-component.
//    findCenter leaves every sub-component with at most half the vertices, so
//    the recursion depth is O(log n).
void solve(int at) {
    const int root = findCenter(at);
    dfs2(root, root, root);

    bool ok = true;        // becomes false once a colour escapes this component
    queue<int> pending;    // vertices whose path to `root` still has to be walked
    vector<int> merged;    // colours added to the candidate set, in order
    pending.push(root);

    while (ok && !pending.empty()) {
        int v = pending.front();
        pending.pop();

        // Walk from v toward root. center[v] == 0 marks an already walked
        // vertex, so every vertex is walked at most once per call.
        while (ok && center[v] != 0) {
            const int col = c[v];
            if (!vizcol[col]) {
                vizcol[col] = true;
                merged.push_back(col);
                for (int u : C[col]) {
                    pending.push(u);
                    // Vertices of a not-yet-added colour are never walked, so
                    // center[u] != root only when u lies outside this component.
                    if (center[u] != root) {
                        ok = false;
                        break;
                    }
                }
            }
            center[v] = 0;
            v = par[v];
            if (v == root) break;  // root is walked first, from the initial push
        }
    }

    for (int col : merged) {
        vizcol[col] = false;
    }

    if (ok) {
        ans = min(ans, static_cast<int>(merged.size()));
    }

    viz[root] = true;
    for (int to : g[root]) {
        if (!viz[to]) {
            solve(to);
        }
    }
}

// Reads n and k, then n - 1 undirected edges, then the colour of each vertex
// 1..n. Runs the centroid decomposition starting from vertex 1 and prints
// ans - 1 (the submitted answer for JOI 2020 "Capital City").
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    cin >> n >> k;
    for (int e = 1; e < n; ++e) {
        int a, b;
        cin >> a >> b;
        g[a].push_back(b);
        g[b].push_back(a);
    }

    for (int v = 1; v <= n; ++v) {
        cin >> c[v];
        C[c[v]].push_back(v);
    }

    solve(1);

    cout << ans - 1 << endl;
    return 0;
}
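// Four result tables from the judge page; none had loaded when the page was captured:
// # | Verdict | Execution time | Memory | Grader output -- "Fetching results..."
// # | Verdict | Execution time | Memory | Grader output -- "Fetching results..."
// # | Verdict | Execution time | Memory | Grader output -- "Fetching results..."
// # | Verdict | Execution time | Memory | Grader output -- "Fetching results..."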