답안 #156254

# 제출 시각 아이디 문제 언어 결과 실행 시간 메모리
156254 2019-10-04T16:24:27 Z popovicirobert Mergers (JOI19_mergers) C++14
0 / 100
114 ms 26056 KB
#include <bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define lsb(x) (x & (-x)) 
    
using namespace std;

const int MAXN = (int) 5e5;
    
struct DSU {
	vector <int> par;
	int n;

	inline void init(int _n) {
		n = _n;
		par.resize(n + 1);
	}

	int root(int x) {
		if(par[x] == 0) return x;
		return par[x] = root(par[x]);
	}

	inline void join(int x, int y) {
		x = root(x), y = root(y);
		if(x != y) {
			par[x] = y;
		}
	}
};

vector <int> g[MAXN + 1];
int idl[MAXN + 1], idr[MAXN + 1], sz;
int anc[MAXN + 1][20], lvl[MAXN + 1];

void dfs(int nod) {
	idl[nod] = ++sz;
	for(auto it : g[nod]) {
		if(lvl[it] == 0) {
			anc[it][0] = nod;
			lvl[it] = lvl[nod] + 1;
			dfs(it);
		}
	}
	idr[nod] = sz;
}
   
int main() {
    //ifstream cin("A.in");
    //ofstream cout("A.out");
    int i, n, k;
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
		
	cin >> n >> k;

	for(i = 1; i < n; i++) {
		int x, y;
		cin >> x >> y;
		g[x].push_back(y);
		g[y].push_back(x);
	}

	vector < vector <int> > col(k + 1);
	for(i = 1; i <= n; i++) {
		int c;
		cin >> c;
		col[c].push_back(i);
	}

	lvl[1] = 1;
	dfs(1);

	for(int bit = 1; bit < 20; bit++) {
		for(i = 1; i <= n; i++) {
			anc[i][bit] = anc[anc[i][bit - 1]][bit - 1];
		}
	}

	auto get_lca = [&](int x, int y) {
		if(lvl[x] < lvl[y]) {
			swap(x, y);
		}
		int dst = lvl[x] - lvl[y];
		for(int bit = 19; bit >= 0; bit--) {
			if(dst & (1 << bit)) {
				x = anc[x][bit];
			}
		}
		if(x == y) return x;
		for(int bit = 19; bit >= 0; bit--) {
			if(anc[x][bit] == anc[y][bit]) {
				x = anc[x][bit], y = anc[y][bit];
			}
		}
		return anc[x][0];
	};

	auto in = [&](int x, int y) -> bool {
		return idl[x] <= idl[y] && idr[y] <= idr[x];
	};

	DSU dsu; dsu.init(n);

	for(int c = 1; c <= k; c++) {
		sort(col[c].begin(), col[c].end(), [&](const int &x, const int &y) {
				return idl[x] < idl[y];
			});

		vector <int> nodes = col[c];
		for(i = 0; i + 1 < col[c].size(); i++) {
			nodes.push_back(get_lca(nodes[i], nodes[i + 1]));
		}
		
		sort(nodes.begin(), nodes.end(), [&](const int &x, const int &y) {
				return idl[x] < idl[y];
			});
		nodes.resize(unique(nodes.begin(), nodes.end()) - nodes.begin());

		stack <int> stk;	
		for(auto nod : nodes) {
			while(stk.size() && in(stk.top(), nod) == 0) {
				stk.pop();
			}
			if(stk.size()) {
				int cur = nod;
				while(cur > 1 && lvl[cur] >= lvl[stk.top()]) {
					dsu.join(cur, anc[cur][0]);
					cur = dsu.root(cur);
				}
			}
			stk.push(nod);
		}
	}

	vector <int> deg(n + 1);

	for(i = 1; i <= n; i++) {
		int x = dsu.root(i);
		for(auto it : g[i]) {
			int y = dsu.root(it);
			if(x != y) {
				deg[x]++;
			}
		}
	}

	int ans = 0;
	for(i = 1; i <= n; i++) {
		if(dsu.root(i) == i && deg[i] == 1) {
			ans++;
		}
	}

	cout << (ans + 1) / 2;
            
	return 0;
}

Compilation message

mergers.cpp: In function 'int main()':
mergers.cpp:111:20: warning: comparison between signed and unsigned integer expressions [-Wsign-compare]
   for(i = 0; i + 1 < col[c].size(); i++) {
              ~~~~~~^~~~~~~~~~~~~~~
# 결과 실행 시간 메모리 Grader output
1 Correct 12 ms 12024 KB Output is correct
2 Incorrect 13 ms 12152 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 12 ms 12024 KB Output is correct
2 Incorrect 13 ms 12152 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 12 ms 12024 KB Output is correct
2 Incorrect 13 ms 12152 KB Output isn't correct
3 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Incorrect 114 ms 26056 KB Output isn't correct
2 Halted 0 ms 0 KB -
# 결과 실행 시간 메모리 Grader output
1 Correct 12 ms 12024 KB Output is correct
2 Incorrect 13 ms 12152 KB Output isn't correct
3 Halted 0 ms 0 KB -