# | Time | Username | Problem | Language | Result | Execution time | Memory |
---|---|---|---|---|---|---|---|
263370 | mjkocijan | Mergers (JOI19_mergers) | C++14 | 192 ms | 75092 KiB |
This submission was migrated from a previous version of oj.uz, which used a different machine for grading. It may produce a different result if resubmitted.
#include <bits/stdc++.h>
using namespace std;
#define X first
#define Y second
#define pb push_back
typedef long long ll;
typedef pair<ll, ll> ii;
// Capacity of every per-vertex / per-group array below.
#define MAXN 500500
// Number of binary-lifting levels kept in anc (levels 0 .. LOG-1).
#define LOG 20
// n: number of vertices, k: number of group labels; both are read in main.
int n, k;
// g[v]: adjacency list of vertex v (0-based), filled from the n-1 input edges.
vector<int> g[MAXN];
// s[v]   : 0-based group label of vertex v.
// rt[v]  : disjoint-set parent of v; v is a representative when rt[v] == v.
// sz     : not used by any active code in this view (only by commented-out lines).
// dep[v] : depth of v as assigned by dfs (the start vertex gets the depth passed in).
// st1[v] : number of vertices in the subtree of v, computed by dfs.
// st2[v] : starts at 1 and accumulates the st2 of components merged into v by merg.
int s[MAXN], rt[MAXN], sz[MAXN], dep[MAXN], st1[MAXN], st2[MAXN];
// sr[c]: vertices whose group label is c, in increasing vertex order.
vector<int> sr[MAXN];
// anc[i][v]: ancestor of v that is 2^i steps up; anc[0][v] is the parent set by dfs,
// higher levels are filled in main. The root is its own ancestor.
int anc[LOG][MAXN];
// Roots the part of the tree reachable from `cv` without going through `rod`.
//
// For every visited vertex this fills:
//   dep[v]    - depth, where `cv` gets depth `dub`;
//   anc[0][v] - parent (`cv` gets `rod`);
//   st1[v]    - number of vertices in the subtree of v.
// Returns st1[cv].
//
// The traversal is iterative on purpose: the tree may be a path of up to
// MAXN vertices, and one recursion frame per level can exhaust the call stack.
int dfs(int cv, int rod, int dub = 0)
{
    dep[cv] = dub;
    anc[0][cv] = rod;

    // Visit order in which every vertex appears after its parent.
    std::vector<int> order(1, cv);
    for (std::size_t head = 0; head < order.size(); head++) {
        const int u = order[head];
        st1[u] = 1;
        for (int v : g[u]) {
            if (v != anc[0][u]) {
                dep[v] = dep[u] + 1;
                anc[0][v] = u;
                order.push_back(v);
            }
        }
    }

    // Walk the order backwards so each subtree size is final before it is
    // added to its parent; index 0 is `cv` itself and has no parent to update.
    for (std::size_t idx = order.size() - 1; idx > 0; idx--) {
        const int u = order[idx];
        st1[anc[0][u]] += st1[u];
    }
    return st1[cv];
}
int lca(int x, int y)
{
if (dep[x] > dep[y]) swap(x, y);
for (int i = LOG - 1; i >= 0; i--) {
int novi = anc[i][y];
if (dep[novi] >= dep[x])
y = novi;
}
if (x == y) return x;
for (int i = LOG - 1; i >= 0; i--) {
if (anc[i][x] != anc[i][y]) {
x = anc[i][x];
y = anc[i][y];
}
}
return anc[0][x];
}
int fn(int x)
{
if (x == rt[x]) return x;
return rt[x] = fn(rt[x]);
}
// Joins the component of x into the component of y.
//
// The representative of x is attached under the representative of y, so the
// surviving representative is y's. Nothing happens when both already share a
// component, or when x's representative is strictly shallower than y's.
// st2 of the surviving representative absorbs st2 of the attached one.
void merg(int x, int y)
{
    const int from = fn(x);
    const int into = fn(y);
    if (from != into && dep[from] >= dep[into]) {
        rt[from] = into;
        st2[into] += st2[from];
    }
}
// sus[r]: representatives of the components adjacent (by a tree edge) to the
// component whose representative is r; filled in main. Indexed by vertex id.
set<int> sus[MAXN];
// Reads a tree with n vertices and a group label (1..k) for every vertex,
// merges all vertices lying on paths between same-group vertices into one
// component, and prints (L + 1) / 2 where L is the number of components that
// touch exactly one other component (the leaves of the contracted tree).
//
// Input format: "n k", then n-1 lines "q w" (1-based edge endpoints), then n
// group labels, one per vertex.
int main()
{
    scanf("%d%d", &n, &k);
    for (int i = 0; i < n - 1; i++) {
        int q, w;
        scanf("%d%d", &q, &w);
        q--; w--;
        g[q].pb(w);
        g[w].pb(q);
    }
    for (int i = 0; i < n; i++) {
        rt[i] = i;
        st2[i] = 1;
        scanf("%d", &s[i]);
        s[i]--;
        sr[s[i]].pb(i);
    }

    // Root the tree at vertex 0 and build the binary-lifting table.
    dfs(0, 0);
    for (int i = 1; i < LOG; i++) {
        for (int j = 0; j < n; j++) {
            anc[i][j] = anc[i - 1][anc[i - 1][j]];
        }
    }

    for (int i = 0; i < k; i++) {
        // A label with no vertices has nothing to merge (and no sr[i][0]).
        if (sr[i].empty()) continue;

        // LCA of every vertex carrying label i.
        int lcaa = sr[i][0];
        for (size_t j = 1; j < sr[i].size(); j++) {
            lcaa = lca(lcaa, sr[i][j]);
        }
        // Climb from each such vertex towards that LCA, merging every
        // component on the way into the one above it.
        for (int j: sr[i]) {
            j = fn(j);
            while (dep[j] > dep[lcaa]) {
                merg(j, anc[0][j]);
                j = fn(j);
            }
        }
    }

    // Record, per component representative, which other components it
    // touches through a tree edge.
    int reza = 0;
    for (int i = 0; i < n; i++) {
        for (int j: g[i]) {
            if (fn(i) != fn(j))
                sus[fn(i)].insert(fn(j));
        }
    }
    // sus is indexed by representative vertex id, so every vertex in [0, n)
    // must be inspected; stopping at k would miss representatives >= k.
    for (int i = 0; i < n; i++)
        if (sus[i].size() == 1)
            reza++;
    printf("%d\n", (reza + 1) / 2);
    return 0;
}
Compilation message (stderr)
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |
# | Verdict | Execution time | Memory | Grader output |
---|---|---|---|---|
Fetching results... |