답안 #330853

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
330853 2020-11-26T18:43:01 Z limabeans Mergers (JOI19_mergers) C++17
0 / 100
306 ms 95076 KB
#include <bits/stdc++.h>
using namespace std;

// Prints `value` followed by std::endl (newline + flush) to stdout and then
// terminates the process with exit status 0. Used to emit a final answer from
// any point in the program.
template <typename T>
void out(T value) {
    cout << value << endl;
    exit(0);
}

// Debug helper: prints "<expression text> is <value>" on its own line.
#define watch(x)  cout << (#x) << " is " << (x) << endl

using ll = long long;

// Disjoint-set union over elements 0..n-1 using path compression and union by
// size. Also tracks the current number of components (`cc`) and the size of
// the largest component seen so far (`largest`).
struct dsu0 {
    vector<int> par, siz;
    int n;
    int cc;
    int largest;

    // Resets the structure to n singleton sets. Requires n > 0.
    void init(int n) {
        assert(n > 0);
        this->n = n;
        cc = n;
        par.resize(n + 10);
        siz.resize(n + 10);
        for (int v = 0; v < n; ++v) {
            par[v] = v;
            siz[v] = 1;
        }
        largest = 1;
    }

    // Returns the representative of x's set, compressing the path on the way.
    int parent(int x) {
        assert(x >= 0 && x < n);
        if (par[x] != x) {
            par[x] = parent(par[x]);
        }
        return par[x];
    }

    // Merges the sets containing x and y (smaller under larger).
    // Returns false if they were already in the same set, true otherwise.
    bool join(int x, int y) {
        int rx = parent(x);
        int ry = parent(y);
        if (rx == ry) {
            return false;
        }
        if (siz[rx] < siz[ry]) {
            swap(rx, ry);
        }
        par[ry] = rx;
        siz[rx] += siz[ry];
        --cc;
        largest = max(largest, siz[rx]);
        return true;
    }
};


const int maxn = 5e5 + 10;

// Input sizes: n vertices, k colours.
int n, k;

// g: adjacency lists of the input tree (0-indexed vertices).
vector<int> g[maxn];
// a[v]: 0-indexed colour of vertex v.
int a[maxn];

// Number of binary-lifting levels; 2^LOG (> 1e6) exceeds any depth < maxn.
const int LOG = 20;

// Euler-tour entry/exit stamps. tin and tout share the counter `cloc`, so u is
// in the subtree of v iff tin[v] <= tin[u] <= tout[v].
int tin[maxn];
int tout[maxn];
int cloc = 0;
// dep[v]: depth of v below the DFS root.
int dep[maxn];
// par[j][v]: 2^j-th ancestor of v. Globals are zero-initialised, so entries
// above the root stay 0 (vertex 0, which main uses as the root).
int par[LOG+1][maxn];

// Returns the lowest common ancestor of u and v.
// Requires `par` and `dep` to have been filled (by dfs) for both vertices.
int lca(int u, int v) {
    // Ensure v is the deeper (or equally deep) vertex.
    if (dep[u] > dep[v]) swap(u, v);

    // Lift v (the deeper vertex) up to u's depth. Each step must jump from v's
    // current position, i.e. par[j][v], not par[j][u].
    int dx = dep[v] - dep[u];
    for (int j = LOG - 1; j >= 0; j--) {
        if (dx >> j & 1) {
            v = par[j][v];
        }
    }

    if (u == v) return v;

    // Lift both while their 2^j-th ancestors differ; afterwards u and v are
    // distinct children of the LCA.
    for (int j = LOG - 1; j >= 0; j--) {
        if (par[j][u] != par[j][v]) {
            u = par[j][u];
            v = par[j][v];
        }
    }

    return par[0][v];
}

// bycolor[c]: entry stamps (tin) of every vertex of colour c. The stamps are
// appended by dfs in increasing order, so each list is sorted.
vector<int> bycolor[maxn];
// nodes[c]: vertex ids of colour c, in input order.
vector<int> nodes[maxn];

// Recursive DFS from `at` (parent `p`, or -1 at the root) that assigns
// tin/tout stamps, records each vertex's stamp under its colour, fills the
// binary-lifting rows for `at`, and sets depth and immediate parent of each
// child before descending. par[0][at] must already be set by the caller.
void dfs(int at, int p) {
    const int enter = cloc++;
    tin[at] = enter;
    bycolor[a[at]].push_back(enter);

    for (int lvl = 1; lvl < LOG; ++lvl) {
        const int half = par[lvl - 1][at];
        par[lvl][at] = par[lvl - 1][half];
    }

    for (int child : g[at]) {
        if (child != p) {
            par[0][child] = at;
            dep[child] = dep[at] + 1;
            dfs(child, at);
        }
    }

    tout[at] = cloc++;
}

dsu0 dsu;

// Walks downward from `at` (arriving from `p`) along every child edge whose
// subtree still contains a vertex of colour c, joining both endpoints in `dsu`.
// Started from the LCA of colour c, this merges exactly the edges of the
// minimal subtree spanning that colour.
// NOTE(review): vertices can be revisited once per colour whose spanning
// subtree contains them; the total work is the sum of those subtree sizes,
// which may be large on adversarial inputs — consider a DSU-based upward walk.
void dfs2(int at, int p, int c) {
    const auto &stamps = bycolor[c];
    for (int to : g[at]) {
        if (to == p) continue;
        // First colour-c stamp at or after tin[to]; it lies inside to's
        // subtree iff it is also <= tout[to].
        auto hit = lower_bound(stamps.begin(), stamps.end(), tin[to]);
        const bool subtreeHasColour = hit != stamps.end() && *hit <= tout[to];
        if (!subtreeHasColour) continue;
        dsu.join(at, to);
        dfs2(to, at, c);
    }
}



// G[r]: neighbouring DSU representatives of representative r, i.e. the
// adjacency of the tree obtained by contracting every DSU component.
set<int> G[maxn];

// Traverses the original tree from `at` (parent `p`) and, for each tree edge
// whose endpoints lie in different DSU components, records that edge between
// the two representatives in G (both directions).
void dfs3(int at, int p) {
    for (int to : g[at]) {
        if (to != p) {
            const int rootHere = dsu.parent(at);
            const int rootThere = dsu.parent(to);
            if (rootHere != rootThere) {
                G[rootHere].insert(rootThere);
                G[rootThere].insert(rootHere);
            }
            dfs3(to, at);
        }
    }
}

// Reads the tree (n vertices, n-1 edges, 1-indexed) and the colour of each
// vertex (1..k), contracts the minimal subtree spanning each colour into one
// DSU component, and prints ceil(leaves / 2), where `leaves` counts the
// degree-1 nodes of the contracted tree. Prints 0 if everything collapses
// into a single component.
int main() {
    ios_base::sync_with_stdio(false); cin.tie(0);  cout.tie(0);

    cin >> n >> k;
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        cin >> u >> v;
        --u; --v;
        g[u].push_back(v);
        g[v].push_back(u);
    }

    for (int i = 0; i < n; i++) {
        cin >> a[i];
        --a[i];
        nodes[a[i]].push_back(i);
    }

    // Root the tree at vertex 0: fills tin/tout, depths, lifting table, bycolor.
    dfs(0, -1);
    dsu.init(n);

    // For each colour, merge the minimal subtree spanning its vertices into a
    // single DSU component (a "supernode").
    for (int j = 0; j < k; j++) {
        // A colour with no vertices has nothing to merge; this also avoids
        // reading nodes[j][0] from an empty vector.
        if (nodes[j].empty()) continue;

        int mid = nodes[j][0];
        for (int x : nodes[j]) {
            mid = lca(mid, x);
        }

        // Start from the colour's LCA; the root (vertex 0) has no parent.
        dfs2(mid, mid == 0 ? -1 : par[0][mid], j);
    }

    // Representatives of all contracted components.
    set<int> st;
    for (int i = 0; i < n; i++) {
        st.insert(dsu.parent(i));
    }

    if ((int)st.size() == 1) out(0);

    // Build the contracted tree and count its leaves (degree-1 nodes).
    dfs3(0, -1);

    int leaves = 0;
    for (int x : st) {
        if ((int)G[x].size() == 1) leaves++;
    }

    // ceil(leaves / 2)
    int ans = (leaves + 1) / 2;
    out(ans);

    return 0;
}
# 결과 실행 시간 메모리 Grader output
1 Correct 38 ms 59372 KB Output is correct
2 Incorrect 36 ms 59244 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 38 ms 59372 KB Output is correct
2 Incorrect 36 ms 59244 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 38 ms 59372 KB Output is correct
2 Incorrect 36 ms 59244 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 276 ms 75492 KB Output is correct
2 Correct 306 ms 95076 KB Output is correct
3 Incorrect 41 ms 59884 KB Output isn't correct
4 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 38 ms 59372 KB Output is correct
2 Incorrect 36 ms 59244 KB Output isn't correct
3 Halted 0 ms 0 KB -